use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};

use arrow::array::AsArray;
use arrow::compute::{concat_batches, TakeOptions};
use arrow::datatypes::UInt64Type;
use arrow_array::{Array, UInt32Array};
use arrow_array::{RecordBatch, UInt64Array};
use arrow_schema::{Schema as ArrowSchema, SchemaRef};
use datafusion::common::Statistics;
use datafusion::error::{DataFusionError, Result};
use datafusion::physical_plan::metrics::{
    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricValue, MetricsSet,
};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream,
};
use datafusion_physical_expr::EquivalenceProperties;
use futures::stream::{FuturesOrdered, Stream, StreamExt, TryStreamExt};
use futures::FutureExt;
use lance_arrow::RecordBatchExt;
use lance_core::datatypes::{Field, OnMissing, Projection};
use lance_core::error::{DataFusionResult, LanceOptionExt};
use lance_core::utils::address::RowAddress;
use lance_core::utils::tokio::get_num_compute_intensive_cpus;
use lance_core::{ROW_ADDR, ROW_ID};
use lance_io::scheduler::{ScanScheduler, SchedulerConfig};

use crate::dataset::fragment::{FragReadConfig, FragmentReader};
use crate::dataset::rowids::get_row_id_index;
use crate::dataset::Dataset;
use crate::datatypes::Schema;

use super::utils::IoMetrics;

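/// Metrics reported by [`TakeStream`]: baseline output/compute metrics, a
/// counter of batches processed, and I/O statistics.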
#[derive(Debug, Clone)]
struct TakeStreamMetrics {
    baseline_metrics: BaselineMetrics,
    batches_processed: Count,
    io_metrics: IoMetrics,
}

impl TakeStreamMetrics {
    fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
        let batches_processed = Count::new();
        MetricBuilder::new(metrics)
            .with_partition(partition)
            .build(MetricValue::Count {
                name: Cow::Borrowed("batches_processed"),
                count: batches_processed.clone(),
            });
        Self {
            baseline_metrics: BaselineMetrics::new(metrics, partition),
            batches_processed,
            io_metrics: IoMetrics::new(metrics, partition),
        }
    }
}

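/// A stream adapter that takes the requested fields for each input batch and
/// merges them into the batch.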
struct TakeStream {
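    /// The dataset rows are taken from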
    dataset: Arc<Dataset>,
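    /// The fields to take (columns not already present in the input)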
    fields_to_take: Arc<Schema>,
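    /// The schema of the merged output batches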
    output_schema: SchemaRef,
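    /// Fragment readers opened so far, cached by fragment id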
    readers_cache: Arc<Mutex<HashMap<u32, Arc<FragmentReader>>>>,
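    /// Scheduler shared by all fragment reads issued by this stream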
    scan_scheduler: Arc<ScanScheduler>,
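    /// Runtime metrics for this stream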
    metrics: TakeStreamMetrics,
}

impl TakeStream {
    fn new(
        dataset: Arc<Dataset>,
        fields_to_take: Arc<Schema>,
        output_schema: SchemaRef,
        scan_scheduler: Arc<ScanScheduler>,
        metrics: &ExecutionPlanMetricsSet,
        partition: usize,
    ) -> Self {
        Self {
            dataset,
            fields_to_take,
            output_schema,
            readers_cache: Arc::new(Mutex::new(HashMap::new())),
            scan_scheduler,
            metrics: TakeStreamMetrics::new(metrics, partition),
        }
    }

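    /// Opens a reader over the fields to take for the given fragment and
    /// caches it for reuse.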
    async fn do_open_reader(&self, fragment_id: u32) -> DataFusionResult<Arc<FragmentReader>> {
        let fragment = self
            .dataset
            .get_fragment(fragment_id as usize)
            .ok_or_else(|| {
                DataFusionError::Execution(format!("The input to a take operation specified fragment id {} but this fragment does not exist in the dataset", fragment_id))
            })?;

        let reader = Arc::new(
            fragment
                .open(
                    &self.fields_to_take,
                    FragReadConfig::default().with_scan_scheduler(self.scan_scheduler.clone()),
                )
                .await?,
        );

        let mut readers = self.readers_cache.lock().unwrap();
        readers.insert(fragment_id, reader.clone());
        Ok(reader)
    }

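    /// Returns a reader for the given fragment, reusing a cached one if
    /// available.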
    async fn open_reader(&self, fragment_id: u32) -> DataFusionResult<Arc<FragmentReader>> {
        if let Some(reader) = self
            .readers_cache
            .lock()
            .unwrap()
            .get(&fragment_id)
            .cloned()
        {
            return Ok(reader);
        }

        self.do_open_reader(fragment_id).await
    }

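    /// Returns the row addresses for `batch`, preferring an existing row
    /// address column; otherwise row ids are translated through the row id
    /// index when the dataset has one, and used as addresses directly when
    /// it does not.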
    async fn get_row_addrs(&self, batch: &RecordBatch) -> Result<Arc<dyn Array>> {
        if let Some(row_addr_array) = batch.column_by_name(ROW_ADDR) {
            Ok(row_addr_array.clone())
        } else {
            let row_id_array = batch.column_by_name(ROW_ID).expect_ok()?;

            if let Some(row_id_index) = get_row_id_index(&self.dataset).await? {
                let row_id_array = row_id_array.as_primitive::<UInt64Type>();
                let addresses = row_id_array
                    .values()
                    .iter()
                    .filter_map(|id| row_id_index.get(*id).map(|address| address.into()))
                    .collect::<Vec<u64>>();
                Ok(Arc::new(UInt64Array::from(addresses)))
            } else {
                Ok(row_id_array.clone())
            }
        }
    }

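    /// Takes the new fields for every row in `batch`, grouping the reads by
    /// fragment, and merges them into the batch in the original row order.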
    async fn map_batch(
        self: Arc<Self>,
        batch: RecordBatch,
        batch_number: u32,
    ) -> DataFusionResult<RecordBatch> {
        let compute_timer = self.metrics.baseline_metrics.elapsed_compute().timer();
        let row_addrs_arr = self.get_row_addrs(&batch).await?;
        let row_addrs = row_addrs_arr.as_primitive::<UInt64Type>();

        debug_assert!(
            row_addrs.null_count() == 0,
            "{} nulls in row addresses",
            row_addrs.null_count()
        );
        let is_sorted = row_addrs.values().is_sorted();

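        // Sort the addresses so the reads are grouped by fragment and issued
        // in address order, remembering how to restore the original order.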
        let sorted_addrs: Arc<dyn Array>;
        let (sorted_addrs, permutation) = if is_sorted {
            (row_addrs, None)
        } else {
            let permutation = arrow::compute::sort_to_indices(&row_addrs_arr, None, None).unwrap();
            sorted_addrs = arrow::compute::take(
                &row_addrs_arr,
                &permutation,
                Some(TakeOptions {
                    check_bounds: false,
                }),
            )
            .unwrap();
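            // `permutation[i]` is the original position of the i-th smallest
            // address; invert it so each original row can locate its data in
            // the sorted result.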
            let mut inverse_permutation = vec![0; permutation.len()];
            for (i, p) in permutation.values().iter().enumerate() {
                inverse_permutation[*p as usize] = i as u32;
            }
            (
                sorted_addrs.as_primitive::<UInt64Type>(),
                Some(UInt32Array::from(inverse_permutation)),
            )
        };

        let mut futures = FuturesOrdered::new();
        let mut current_offsets = Vec::new();
        let mut current_fragment_id = None;

        for row_addr in sorted_addrs.values() {
            let addr = RowAddress::new_from_u64(*row_addr);

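            // Each time we cross a fragment boundary, flush the offsets
            // accumulated for the previous fragment as one take.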
            if Some(addr.fragment_id()) != current_fragment_id {
                if let Some(fragment_id) = current_fragment_id {
                    let reader = self.open_reader(fragment_id).await?;
                    let offsets = std::mem::take(&mut current_offsets);
                    futures.push_back(
                        async move { reader.take_as_batch(&offsets, Some(batch_number)).await }
                            .boxed(),
                    );
                }
                current_fragment_id = Some(addr.fragment_id());
            }

            current_offsets.push(addr.row_offset());
        }

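        // Flush the offsets accumulated for the final fragment.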
        if let Some(fragment_id) = current_fragment_id {
            let reader = self.open_reader(fragment_id).await?;
            futures.push_back(
                async move {
                    reader
                        .take_as_batch(&current_offsets, Some(batch_number))
                        .await
                }
                .boxed(),
            );
        }

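        // Stop counting compute time while we wait on the take I/O.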
        drop(compute_timer);

        let batches = futures.try_collect::<Vec<_>>().await?;

        if batches.is_empty() {
            return Ok(RecordBatch::new_empty(self.output_schema.clone()));
        }

        let _compute_timer = self.metrics.baseline_metrics.elapsed_compute().timer();
        let schema = batches.first().expect_ok()?.schema();
        let mut new_data = concat_batches(&schema, batches.iter())?;

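        // Move the taken rows back into the caller's original order.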
        if let Some(permutation) = permutation {
            new_data = arrow_select::take::take_record_batch(&new_data, &permutation).unwrap();
        }

        self.metrics
            .baseline_metrics
            .record_output(new_data.num_rows());
        self.metrics.batches_processed.add(1);
        Ok(batch.merge_with_schema(&new_data, self.output_schema.as_ref())?)
    }

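    /// Applies the take to every batch in `input`, spawning one task per
    /// batch and keeping up to `get_num_compute_intensive_cpus()` of them in
    /// flight; `try_buffered` preserves the input batch order.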
    fn apply<S: Stream<Item = Result<RecordBatch>> + Send + 'static>(
        self: Arc<Self>,
        input: S,
    ) -> impl Stream<Item = Result<RecordBatch>> {
        let scan_scheduler = self.scan_scheduler.clone();
        let metrics = self.metrics.clone();
        let batches = input
            .enumerate()
            .map(move |(batch_index, batch)| {
                let batch = batch?;
                let this = self.clone();
                Ok(
                    tokio::task::spawn(this.map_batch(batch, batch_index as u32))
                        .map(|res| res.unwrap()),
                )
            })
            .boxed();
        batches
            .inspect_ok(move |_| metrics.io_metrics.record(&scan_scheduler))
            .try_buffered(get_num_compute_intensive_cpus())
    }
}

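/// An execution node that takes rows from a dataset and merges the requested
/// columns into its input batches, matching rows by row id or row address.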
#[derive(Debug)]
pub struct TakeExec {
    dataset: Arc<Dataset>,
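    /// The projection requested by the caller, kept so the node can be
    /// reconstructed when its children are replaced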
    output_projection: Projection,
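    /// The fields to take from the dataset (the requested projection minus
    /// the columns the input already provides)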
    schema_to_take: Arc<Schema>,
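    /// The schema of the batches this node produces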
    output_schema: SchemaRef,
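    /// The child plan, which must provide a row id or row address column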
    input: Arc<dyn ExecutionPlan>,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
}

impl DisplayAs for TakeExec {
    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let extra_fields = self
            .schema_to_take
            .fields
            .iter()
            .map(|f| f.name.clone())
            .collect::<HashSet<_>>();
        let columns = self
            .output_schema
            .fields
            .iter()
            .map(|f| {
                let name = f.name();
                if extra_fields.contains(name) {
                    format!("({})", name)
                } else {
                    name.clone()
                }
            })
            .collect::<Vec<_>>()
            .join(", ");
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "Take: columns={:?}", columns)
            }
            DisplayFormatType::TreeRender => {
                write!(f, "Take\ncolumns={:?}", columns)
            }
        }
    }
}

impl TakeExec {
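    /// Creates a `TakeExec` that takes `projection` from `dataset` and merges
    /// the taken columns into the batches produced by `input`.
    ///
    /// Columns the input already carries are not taken again. Returns
    /// `Ok(None)` if nothing is left to take, and an error if the input has
    /// neither a row id nor a row address column.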
    pub fn try_new(
        dataset: Arc<Dataset>,
        input: Arc<dyn ExecutionPlan>,
        projection: Projection,
    ) -> Result<Option<Self>> {
        let original_projection = projection.clone();
        let projection =
            projection.subtract_arrow_schema(input.schema().as_ref(), OnMissing::Ignore)?;
        if !projection.has_data_fields() {
            return Ok(None);
        }

        if input.schema().column_with_name(ROW_ADDR).is_none()
            && input.schema().column_with_name(ROW_ID).is_none()
        {
            return Err(DataFusionError::Plan(format!(
                "TakeExec requires the input plan to have a column named '{}' or '{}'",
                ROW_ADDR, ROW_ID
            )));
        }

        assert!(
            !projection.with_row_id && !projection.with_row_addr,
            "Take should not be used to insert row_id / row_addr: {:#?}",
            projection
        );

        let output_schema = Arc::new(Self::calculate_output_schema(
            dataset.schema(),
            &input.schema(),
            &projection,
        ));
        let output_arrow = Arc::new(ArrowSchema::from(output_schema.as_ref()));
        let properties = input
            .properties()
            .clone()
            .with_eq_properties(EquivalenceProperties::new(output_arrow.clone()));

        Ok(Some(Self {
            dataset,
            output_projection: original_projection,
            schema_to_take: projection.into_schema_ref(),
            input,
            output_schema: output_arrow,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        }))
    }

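    /// Calculates the output schema: every input field, in its original
    /// order, merged with any sub-fields being taken for it, followed by the
    /// newly taken top-level fields.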
    fn calculate_output_schema(
        dataset_schema: &Schema,
        input_schema: &ArrowSchema,
        projection: &Projection,
    ) -> Schema {
        let mut top_level_fields_added = HashSet::with_capacity(input_schema.fields.len());
        let projected_schema = projection.to_schema();

        let mut output_fields =
            Vec::with_capacity(input_schema.fields.len() + projected_schema.fields.len());
        output_fields.extend(input_schema.fields.iter().map(|f| {
            let f = Field::try_from(f.as_ref()).unwrap();
            if let Some(ds_field) = dataset_schema.field(&f.name) {
                top_level_fields_added.insert(ds_field.id);
                if let Some(projected_field) = ds_field.apply_projection(projection) {
                    f.merge_with_reference(&projected_field, ds_field)
                } else {
                    f
                }
            } else {
                f
            }
        }));

        output_fields.extend(
            projected_schema
                .fields
                .into_iter()
                .filter(|f| !top_level_fields_added.contains(&f.id)),
        );
        Schema {
            fields: output_fields,
            metadata: dataset_schema.metadata.clone(),
        }
    }

    pub fn dataset(&self) -> &Arc<Dataset> {
        &self.dataset
    }
}

impl ExecutionPlan for TakeExec {
    fn name(&self) -> &str {
        "TakeExec"
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.output_schema.clone()
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
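        // Splitting the input into more partitions does not make the take
        // cheaper; each input batch already runs on its own task.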
        vec![false]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        if children.len() != 1 {
            return Err(DataFusionError::Internal(
                "TakeExec wrong number of children".to_string(),
            ));
        }

        let projection = self.output_projection.clone();

        let plan = Self::try_new(self.dataset.clone(), children[0].clone(), projection)?;

        if let Some(plan) = plan {
            Ok(Arc::new(plan))
        } else {
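            // Nothing is left to take (the new child already provides every
            // requested column), so the take node can be elided entirely.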
            Ok(children[0].clone())
        }
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<datafusion::execution::context::TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        let input_stream = self.input.execute(partition, context)?;
        let dataset = self.dataset.clone();
        let schema_to_take = self.schema_to_take.clone();
        let output_schema = self.output_schema.clone();
        let metrics = self.metrics.clone();

        let lazy_take_stream = futures::stream::once(async move {
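            // Everything here runs on first poll, inside the runtime that
            // drives the stream, so `execute` itself never needs a runtime.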
            let obj_store = dataset.object_store.clone();
            let scheduler_config = SchedulerConfig::max_bandwidth(&obj_store);
            let scan_scheduler = ScanScheduler::new(obj_store, scheduler_config);

            let take_stream = Arc::new(TakeStream::new(
                dataset,
                schema_to_take,
                output_schema,
                scan_scheduler,
                &metrics,
                partition,
            ));
            take_stream.apply(input_stream)
        });
        let output_schema = self.output_schema.clone();
        Ok(Box::pin(RecordBatchStreamAdapter::new(
            output_schema,
            lazy_take_stream.flatten(),
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn partition_statistics(
        &self,
        partition: Option<usize>,
    ) -> Result<datafusion::physical_plan::Statistics> {
        Ok(Statistics {
            num_rows: self.input.partition_statistics(partition)?.num_rows,
            ..Statistics::new_unknown(self.schema().as_ref())
        })
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use arrow_array::{
        ArrayRef, Float32Array, Int32Array, RecordBatchIterator, StringArray, StructArray,
    };
    use arrow_schema::{DataType, Field, Fields};
    use datafusion::execution::TaskContext;
    use lance_arrow::SchemaExt;
    use lance_core::utils::tempfile::TempStrDir;
    use lance_core::{datatypes::OnMissing, ROW_ID};
    use lance_datafusion::{datagen::DatafusionDatagenExt, exec::OneShotExec, utils::MetricsExt};
    use lance_datagen::{BatchCount, RowCount};
    use rstest::rstest;

    use crate::{
        dataset::WriteParams,
        io::exec::{LanceScanConfig, LanceScanExec},
        utils::test::NoContextTestFixture,
    };

    struct TestFixture {
        dataset: Arc<Dataset>,
        _tmp_dir_guard: TempStrDir,
    }

    async fn test_fixture() -> TestFixture {
        let struct_fields = Fields::from(vec![
            Arc::new(Field::new("x", DataType::Int32, false)),
            Arc::new(Field::new("y", DataType::Int32, false)),
        ]);

        let schema = Arc::new(ArrowSchema::new(vec![
            Field::new("i", DataType::Int32, false),
            Field::new("f", DataType::Float32, false),
            Field::new("s", DataType::Utf8, false),
            Field::new("struct", DataType::Struct(struct_fields.clone()), false),
        ]));

        let expected_batches: Vec<RecordBatch> = (0..3)
            .map(|batch_id| {
                let value_range = batch_id * 10..batch_id * 10 + 10;
                let columns: Vec<ArrayRef> = vec![
                    Arc::new(Int32Array::from_iter_values(value_range.clone())),
                    Arc::new(Float32Array::from_iter(
                        value_range.clone().map(|v| v as f32),
                    )),
                    Arc::new(StringArray::from_iter_values(
                        value_range.clone().map(|v| format!("str-{v}")),
                    )),
                    Arc::new(StructArray::new(
                        struct_fields.clone(),
                        vec![
                            Arc::new(Int32Array::from_iter(value_range.clone())),
                            Arc::new(Int32Array::from_iter(value_range)),
                        ],
                        None,
                    )),
                ];
                RecordBatch::try_new(schema.clone(), columns).unwrap()
            })
            .collect();

        let test_dir = TempStrDir::default();
        let test_uri = test_dir.as_str();
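        // With max_rows_per_file = 10, each of the three 10-row batches is
        // written to its own fragment.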
        let params = WriteParams {
            max_rows_per_file: 10,
            ..Default::default()
        };
        let reader =
            RecordBatchIterator::new(expected_batches.clone().into_iter().map(Ok), schema.clone());
        Dataset::write(reader, test_uri, Some(params))
            .await
            .unwrap();

        TestFixture {
            dataset: Arc::new(Dataset::open(test_uri).await.unwrap()),
            _tmp_dir_guard: test_dir,
        }
    }

    #[tokio::test]
    async fn test_take_schema() {
        let TestFixture { dataset, .. } = test_fixture().await;

        let scan_arrow_schema = ArrowSchema::new(vec![Field::new("i", DataType::Int32, false)]);
        let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap());

        let config = LanceScanConfig {
            with_row_id: true,
            ..Default::default()
        };
        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            scan_schema,
            config,
        ));

        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();
        let take_exec = TakeExec::try_new(dataset, input, projection)
            .unwrap()
            .unwrap();
        let schema = take_exec.schema();
        assert_eq!(
            schema.fields.iter().map(|f| f.name()).collect::<Vec<_>>(),
            vec!["i", ROW_ID, "s"]
        );
    }

    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum TakeInput {
        Ids,
        Addrs,
        IdsAndAddrs,
    }

    #[rstest]
    #[tokio::test]
    async fn test_simple_take(
        #[values(TakeInput::Ids, TakeInput::Addrs, TakeInput::IdsAndAddrs)] take_input: TakeInput,
    ) {
        let TestFixture {
            dataset,
            _tmp_dir_guard,
        } = test_fixture().await;

        let scan_schema = Arc::new(dataset.schema().project(&["i"]).unwrap());
        let config = LanceScanConfig {
            with_row_address: take_input != TakeInput::Ids,
            with_row_id: take_input != TakeInput::Addrs,
            ..Default::default()
        };
        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            scan_schema,
            config,
        ));

        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();
        let take_exec = TakeExec::try_new(dataset, input, projection)
            .unwrap()
            .unwrap();
        let schema = take_exec.schema();

        let mut expected_fields = vec!["i"];
        if take_input != TakeInput::Addrs {
            expected_fields.push(ROW_ID);
        }
        if take_input != TakeInput::Ids {
            expected_fields.push(ROW_ADDR);
        }
        expected_fields.push("s");
        assert_eq!(&schema.field_names(), &expected_fields);

        let mut stream = take_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();

        while let Some(batch) = stream.try_next().await.unwrap() {
            assert_eq!(&batch.schema().field_names(), &expected_fields);
        }
    }

    #[tokio::test]
    async fn test_take_order() {
        let TestFixture {
            dataset,
            _tmp_dir_guard,
        } = test_fixture().await;

        let data = dataset
            .scan()
            .project(&["s"])
            .unwrap()
            .with_row_address()
            .try_into_batch()
            .await
            .unwrap();
        let indices = UInt64Array::from(vec![8, 13, 1, 7, 4, 5, 12, 9, 10, 2, 11, 6, 3, 0, 28]);
        let data = arrow_select::take::take_record_batch(&data, &indices).unwrap();

        let schema = Arc::new(ArrowSchema::new(vec![Field::new(
            ROW_ADDR,
            DataType::UInt64,
            true,
        )]));
        let row_addrs = data.project_by_schema(&schema).unwrap();

        let batches = (0..3)
            .map(|i| {
                let start = i * 5;
                row_addrs.slice(start, 5)
            })
            .collect::<Vec<_>>();

        let row_addr_stream = futures::stream::iter(batches.clone().into_iter().map(Ok));
        let row_addr_stream = Box::pin(RecordBatchStreamAdapter::new(schema, row_addr_stream));

        let input = Arc::new(OneShotExec::new(row_addr_stream));

        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();
        let take_exec = TakeExec::try_new(dataset, input, projection)
            .unwrap()
            .unwrap();

        let stream = take_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();

        let expected = vec![data.slice(0, 5), data.slice(5, 5), data.slice(10, 5)];

        let batches = stream.try_collect::<Vec<_>>().await.unwrap();
        assert_eq!(batches.len(), 3);
        for (batch, expected) in batches.into_iter().zip(expected) {
            assert_eq!(batch.schema().field_names(), vec![ROW_ADDR, "s"]);
            let expected = expected.project_by_schema(&batch.schema()).unwrap();
            assert_eq!(batch, expected);
        }

        let metrics = take_exec.metrics().unwrap();
        assert_eq!(metrics.output_rows(), Some(15));
        assert_eq!(metrics.find_count("batches_processed").unwrap().value(), 3);
    }

    #[tokio::test]
    async fn test_take_struct() {
        let TestFixture {
            dataset,
            _tmp_dir_guard,
        } = test_fixture().await;

        let scan_schema = Arc::new(dataset.schema().project(&["struct.y"]).unwrap());

        let config = LanceScanConfig {
            with_row_address: true,
            ..Default::default()
        };
        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            scan_schema,
            config,
        ));

        let projection = dataset
            .empty_projection()
            .union_column("struct.x", OnMissing::Error)
            .unwrap();

        let take_exec = TakeExec::try_new(dataset, input, projection)
            .unwrap()
            .unwrap();

        let expected_schema = ArrowSchema::new(vec![
            Field::new(
                "struct",
                DataType::Struct(Fields::from(vec![
                    Arc::new(Field::new("x", DataType::Int32, false)),
                    Arc::new(Field::new("y", DataType::Int32, false)),
                ])),
                false,
            ),
            Field::new(ROW_ADDR, DataType::UInt64, true),
        ]);
        let schema = take_exec.schema();
        assert_eq!(schema.as_ref(), &expected_schema);

        let mut stream = take_exec
            .execute(0, Arc::new(TaskContext::default()))
            .unwrap();

        while let Some(batch) = stream.try_next().await.unwrap() {
            assert_eq!(batch.schema().as_ref(), &expected_schema);
        }
    }

    #[tokio::test]
    async fn test_take_no_row_addr() {
        let TestFixture { dataset, .. } = test_fixture().await;

        let scan_arrow_schema = ArrowSchema::new(vec![Field::new("i", DataType::Int32, false)]);
        let scan_schema = Arc::new(Schema::try_from(&scan_arrow_schema).unwrap());

        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();

        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            scan_schema,
            LanceScanConfig::default(),
        ));
        assert!(TakeExec::try_new(dataset, input, projection).is_err());
    }

    #[tokio::test]
    async fn test_with_new_children() -> Result<()> {
        let TestFixture { dataset, .. } = test_fixture().await;

        let config = LanceScanConfig {
            with_row_id: true,
            ..Default::default()
        };

        let input_schema = Arc::new(dataset.schema().project(&["i"])?);
        let projection = dataset
            .empty_projection()
            .union_column("s", OnMissing::Error)
            .unwrap();

        let input = Arc::new(LanceScanExec::new(
            dataset.clone(),
            dataset.fragments().clone(),
            None,
            input_schema,
            config,
        ));

        assert_eq!(input.schema().field_names(), vec!["i", ROW_ID]);
        let take_exec = TakeExec::try_new(dataset.clone(), input.clone(), projection)?.unwrap();
        assert_eq!(take_exec.schema().field_names(), vec!["i", ROW_ID, "s"]);

        let projection = dataset
            .empty_projection()
            .union_columns(["s", "f"], OnMissing::Error)
            .unwrap();

        let outer_take =
            Arc::new(TakeExec::try_new(dataset, Arc::new(take_exec), projection)?.unwrap());
        assert_eq!(
            outer_take.schema().field_names(),
            vec!["i", ROW_ID, "s", "f"]
        );

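        // Rebuilding the outer take with the scan as its only child means the
        // new node takes both "s" and "f" itself, changing the column order.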
        let edited = outer_take.with_new_children(vec![input])?;
        assert_eq!(edited.schema().field_names(), vec!["i", ROW_ID, "f", "s"]);
        Ok(())
    }

    #[test]
    fn no_context_take() {
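        // Regression test: `execute` must work without an active tokio
        // runtime because all I/O is deferred until the stream is polled.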
        let fixture = NoContextTestFixture::new();
        let arc_dataset = Arc::new(fixture.dataset);

        let input = lance_datagen::gen_batch()
            .col(ROW_ID, lance_datagen::array::step::<UInt64Type>())
            .into_df_exec(RowCount::from(50), BatchCount::from(2));

        let take = TakeExec::try_new(
            arc_dataset.clone(),
            input,
            arc_dataset
                .empty_projection()
                .union_column("text", OnMissing::Error)
                .unwrap(),
        )
        .unwrap()
        .unwrap();

        take.execute(0, Arc::new(TaskContext::default())).unwrap();
    }
}