1use std::any::Any;
19use std::fmt;
20use std::sync::Arc;
21
22use datafusion::common::Statistics;
23use datafusion::error::{DataFusionError, Result};
24use datafusion::execution::context::TaskContext;
25use datafusion::execution::SendableRecordBatchStream;
26use futures::future::BoxFuture;
27use futures::FutureExt;
28
29use std::pin::Pin;
30use std::task::{Context, Poll};
31
32use datafusion::arrow::datatypes::{Schema, SchemaRef};
33use datafusion::arrow::record_batch::RecordBatch;
34use datafusion::physical_expr::EquivalenceProperties;
35use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
36use datafusion::physical_plan::{
37 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties,
38 RecordBatchStream,
39};
40use futures::stream::{Stream, StreamExt};
41
42use crate::physical_plan::exec::index::IndexScanExec;
43use crate::physical_plan::exec::sequential_union::SequentialUnionExec;
44use crate::physical_plan::fetcher::RecordFetcher;
45use crate::physical_plan::joins::try_create_index_lookup_join;
46use crate::types::{IndexFilter, IndexFilters, UnionMode};
47use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
48use datafusion::physical_plan::empty::EmptyExec;
49use datafusion::physical_plan::expressions::Column;
50use datafusion::physical_plan::projection::ProjectionExec;
51use datafusion::physical_plan::union::UnionExec;
52use datafusion::physical_plan::PhysicalExpr;
53
/// Physical operator that runs an index-scan sub-plan (`input`) and hands every
/// non-empty batch it produces to a [`RecordFetcher`], emitting whatever the
/// fetcher returns.
#[derive(Debug)]
pub struct RecordFetchExec {
    /// Index filter tree this operator was built from; rendered by `DisplayAs`.
    indexes: Arc<IndexFilters>,
    /// Optional row limit supplied at construction and passed on when the scan
    /// sub-plan is built.
    limit: Option<usize>,
    /// Cached plan properties returned from `properties()`.
    plan_properties: PlanProperties,
    /// Fetcher invoked with each non-empty batch read from `input`.
    record_fetcher: Arc<dyn RecordFetcher>,
    /// Child plan whose output batches are handed to the fetcher.
    input: Arc<dyn ExecutionPlan>,
    /// Metrics set from which per-partition `BaselineMetrics` are created.
    metrics: ExecutionPlanMetricsSet,
    /// Output schema reported by `schema()`.
    schema: SchemaRef,
    /// Selects how `IndexFilter::Or` branches are unioned when the scan
    /// sub-plan is built.
    union_mode: UnionMode,
}
73
74impl RecordFetchExec {
75 pub fn try_new(
84 indexes: Vec<IndexFilter>,
85 limit: Option<usize>,
86 record_fetcher: Arc<dyn RecordFetcher>,
87 schema: SchemaRef,
88 union_mode: UnionMode,
89 ) -> Result<Self> {
90 if indexes.is_empty() {
91 return Err(DataFusionError::Plan(
92 "RecordFetchExec requires at least one index".to_string(),
93 ));
94 }
95
96 if indexes.len() > 1 {
97 return Err(DataFusionError::Internal(
98 "RecordFetchExec expects a single root IndexFilter".to_string(),
99 ));
100 }
101
102 let input = match indexes.first() {
103 Some(index_filter) => Self::build_scan_exec(index_filter, limit, union_mode)?,
104 None => {
105 return Err(DataFusionError::Plan(
106 "RecordFetchExec requires at least one index".to_string(),
107 ));
108 }
109 };
110 let eq_properties = EquivalenceProperties::new(schema.clone());
111 let plan_properties = PlanProperties::new(
112 eq_properties,
113 Partitioning::UnknownPartitioning(1),
114 input.properties().emission_type,
115 input.properties().boundedness,
116 );
117
118 Ok(Self {
119 indexes: indexes.into(),
120 limit,
121 plan_properties,
122 record_fetcher,
123 input,
124 metrics: ExecutionPlanMetricsSet::new(),
125 schema,
126 union_mode,
127 })
128 }
129
130 fn build_scan_exec(
195 index_filter: &IndexFilter,
196 limit: Option<usize>,
197 union_mode: UnionMode,
198 ) -> Result<Arc<dyn ExecutionPlan>> {
199 match index_filter {
200 IndexFilter::Single { index, filter } => {
201 let schema = index.index_schema();
202 let exec =
203 IndexScanExec::try_new(index.clone(), vec![filter.clone()], limit, schema)?;
204 Ok(Arc::new(exec))
205 }
206 IndexFilter::And(filters) => {
207 let mut plans = filters
208 .iter()
209 .map(|f| Self::build_scan_exec(f, limit, union_mode))
210 .collect::<Result<Vec<_>>>()?;
211
212 if plans.is_empty() {
213 return Err(DataFusionError::Plan(
214 "IndexFilter::And requires at least one sub-filter".to_string(),
215 ));
216 }
217
218 let mut left = plans.remove(0);
219 let pk_schema = left.schema();
220 while !plans.is_empty() {
221 let right = plans.remove(0);
222 let joined = try_create_index_lookup_join(left, right)?;
223 left = Self::project_to_pk_schema(joined, &pk_schema)?;
224 }
225 Ok(left)
226 }
227 IndexFilter::Or(filters) => {
228 let original_plans = filters
229 .iter()
230 .map(|f| Self::build_scan_exec(f, limit, union_mode))
231 .collect::<Result<Vec<_>>>()?;
232
233 if original_plans.is_empty() {
234 return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))));
235 }
236
237 let canonical_schema = original_plans[0].schema();
239
240 let normalized_plans: Vec<Arc<dyn ExecutionPlan>> = original_plans
242 .into_iter()
243 .map(|plan| Self::project_to_pk_schema(plan, &canonical_schema))
244 .collect::<Result<Vec<_>>>()?;
245
246 let union_input: Arc<dyn ExecutionPlan> = match union_mode {
248 UnionMode::Parallel => UnionExec::try_new(normalized_plans)?,
249 UnionMode::Sequential => {
250 Arc::new(SequentialUnionExec::try_new(normalized_plans)?)
251 }
252 };
253
254 let group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = canonical_schema
256 .fields()
257 .iter()
258 .enumerate()
259 .map(|(i, field)| {
260 (
261 Arc::new(Column::new(field.name(), i)) as Arc<dyn PhysicalExpr>,
262 field.name().to_string(),
263 )
264 })
265 .collect();
266
267 let group_by = PhysicalGroupBy::new_single(group_exprs);
268
269 let agg_exec = AggregateExec::try_new(
270 AggregateMode::Single,
271 group_by,
272 vec![],
273 vec![],
274 union_input,
275 canonical_schema,
276 )?;
277
278 Ok(Arc::new(agg_exec))
279 }
280 }
281 }
282
283 fn project_to_pk_schema(
289 plan: Arc<dyn ExecutionPlan>,
290 pk_schema: &SchemaRef,
291 ) -> Result<Arc<dyn ExecutionPlan>> {
292 let plan_schema = plan.schema();
293
294 if plan_schema.fields().len() == pk_schema.fields().len()
296 && pk_schema
297 .fields()
298 .iter()
299 .enumerate()
300 .all(|(i, f)| plan_schema.field(i) == f.as_ref())
301 {
302 return Ok(plan);
303 }
304
305 let exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = pk_schema
307 .fields()
308 .iter()
309 .map(|field| {
310 let idx = plan_schema
311 .fields()
312 .iter()
313 .position(|f| f.name() == field.name())
314 .ok_or_else(|| {
315 DataFusionError::Plan(format!(
316 "Primary key column '{}' not found in plan schema: {:?}",
317 field.name(),
318 plan_schema
319 ))
320 })?;
321 Ok((
322 Arc::new(Column::new(field.name(), idx)) as Arc<dyn PhysicalExpr>,
323 field.name().to_string(),
324 ))
325 })
326 .collect::<Result<Vec<_>>>()?;
327
328 Ok(Arc::new(ProjectionExec::try_new(exprs, plan)?))
329 }
330}
331
332impl DisplayAs for RecordFetchExec {
333 fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
334 match t {
335 DisplayFormatType::Default
336 | DisplayFormatType::Verbose
337 | DisplayFormatType::TreeRender => {
338 let index_names: Vec<_> = self.indexes.iter().map(|i| i.to_string()).collect();
339 write!(
340 f,
341 "RecordFetchExec: indexes=[{}], limit={:?}",
342 index_names.join(", "),
343 self.limit
344 )
345 }
346 }
347 }
348}
349
350impl ExecutionPlan for RecordFetchExec {
351 fn name(&self) -> &str {
353 "RecordFetchExec"
354 }
355
356 fn as_any(&self) -> &dyn Any {
359 self
360 }
361
362 fn schema(&self) -> SchemaRef {
364 self.schema.clone()
365 }
366
367 fn properties(&self) -> &PlanProperties {
369 &self.plan_properties
370 }
371
372 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
374 vec![&self.input]
375 }
376
377 fn required_input_distribution(&self) -> Vec<Distribution> {
378 vec![Distribution::SinglePartition]
381 }
382
383 fn with_new_children(
385 self: Arc<Self>,
386 children: Vec<Arc<dyn ExecutionPlan>>,
387 ) -> Result<Arc<dyn ExecutionPlan>> {
388 if children.len() != 1 {
389 return Err(DataFusionError::Internal(
390 "RecordFetchExec should have exactly one child".to_string(),
391 ));
392 }
393 Ok(Arc::new(RecordFetchExec {
394 indexes: self.indexes.clone(),
395 limit: self.limit,
396 plan_properties: self.plan_properties.clone(),
397 record_fetcher: self.record_fetcher.clone(),
398 input: children[0].clone(),
399 metrics: self.metrics.clone(),
400 schema: self.schema.clone(),
401 union_mode: self.union_mode,
402 }))
403 }
404
405 fn execute(
407 &self,
408 partition: usize,
409 context: Arc<TaskContext>,
410 ) -> Result<SendableRecordBatchStream> {
411 if partition != 0 {
412 return Err(DataFusionError::Internal(format!(
413 "RecordFetchExec executed with partition {partition} but expected 0"
414 )));
415 }
416
417 let input_stream = self.input.execute(0, context)?;
418 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
419
420 Ok(Box::pin(RecordFetchStream::new(
421 input_stream,
422 self.record_fetcher.clone(),
423 baseline_metrics,
424 )))
425 }
426
427 fn statistics(&self) -> Result<Statistics> {
429 Ok(Statistics::new_unknown(&self.schema()))
430 }
431}
432
/// Stream that reads batches from an input stream and replaces each non-empty
/// one with the batch returned by a [`RecordFetcher`].
pub struct RecordFetchStream {
    /// Schema reported by `RecordBatchStream::schema`; taken from the fetcher in `new`.
    schema: SchemaRef,
    /// Records the outcome of every poll.
    baseline_metrics: BaselineMetrics,
    /// Current position in the read-then-fetch state machine.
    state: FetchState,
}
442
/// Boxed future for one in-flight fetch.
///
/// On success it resolves to the input stream and fetcher that were moved into
/// it, plus the fetched batch, so the stream can go back to reading input.
type FetchFuture = BoxFuture<
    'static,
    Result<(
        SendableRecordBatchStream,
        Arc<dyn RecordFetcher>,
        RecordBatch,
    )>,
>;
453
/// States of the `RecordFetchStream` polling loop.
enum FetchState {
    /// Waiting for the next batch from the input stream.
    ReadingInput {
        input: SendableRecordBatchStream,
        fetcher: Arc<dyn RecordFetcher>,
    },
    /// A fetch is in flight; the future owns the input stream and fetcher.
    Fetching(FetchFuture),
    /// Placeholder left behind while a poll owns the real state. It also remains
    /// in place after the stream has ended or failed, in which case polling
    /// yields `None`.
    Error,
}
467
468impl RecordFetchStream {
469 pub fn new(
471 input: SendableRecordBatchStream,
472 fetcher: Arc<dyn RecordFetcher>,
473 baseline_metrics: BaselineMetrics,
474 ) -> Self {
475 let schema = fetcher.schema();
476 let state = FetchState::ReadingInput { input, fetcher };
477 Self {
478 schema,
479 baseline_metrics,
480 state,
481 }
482 }
483}
484
485impl Stream for RecordFetchStream {
486 type Item = Result<RecordBatch>;
487
488 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
489 loop {
490 match std::mem::replace(&mut self.state, FetchState::Error) {
491 FetchState::ReadingInput { mut input, fetcher } => {
492 match input.poll_next_unpin(cx) {
493 Poll::Ready(Some(Ok(batch))) if batch.num_rows() > 0 => {
494 let fut = {
496 let fetcher = fetcher.clone();
497 async move {
498 fetcher
499 .fetch(batch)
500 .await
501 .map(|batch| (input, fetcher, batch))
502 }
503 .boxed()
504 };
505 self.state = FetchState::Fetching(fut);
506 }
507 Poll::Ready(Some(Ok(_))) => {
508 self.state = FetchState::ReadingInput { input, fetcher };
510 }
511 Poll::Ready(Some(Err(e))) => {
512 return self.baseline_metrics.record_poll(Poll::Ready(Some(Err(e))));
513 }
514 Poll::Ready(None) => {
515 return self.baseline_metrics.record_poll(Poll::Ready(None));
516 }
517 Poll::Pending => {
518 self.state = FetchState::ReadingInput { input, fetcher };
519 return self.baseline_metrics.record_poll(Poll::Pending);
520 }
521 }
522 }
523 FetchState::Fetching(mut fut) => {
524 match fut.as_mut().poll(cx) {
525 Poll::Ready(Ok((input, fetcher, batch))) if batch.num_rows() > 0 => {
526 self.state = FetchState::ReadingInput { input, fetcher };
528 return self
529 .baseline_metrics
530 .record_poll(Poll::Ready(Some(Ok(batch))));
531 }
532 Poll::Ready(Ok((input, fetcher, _))) => {
533 self.state = FetchState::ReadingInput { input, fetcher };
535 }
536 Poll::Ready(Err(e)) => {
537 return self.baseline_metrics.record_poll(Poll::Ready(Some(Err(e))));
538 }
539 Poll::Pending => {
540 self.state = FetchState::Fetching(fut);
541 return self.baseline_metrics.record_poll(Poll::Pending);
542 }
543 }
544 }
545 FetchState::Error => {
546 return self.baseline_metrics.record_poll(Poll::Ready(None));
547 }
548 }
549 }
550 }
551}
552
553impl fmt::Debug for RecordFetchStream {
554 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
555 f.debug_struct("RecordFetchStream")
556 .field("schema", &self.schema)
557 .field("baseline_metrics", &self.baseline_metrics)
558 .finish()
559 }
560}
561
562impl RecordBatchStream for RecordFetchStream {
563 fn schema(&self) -> SchemaRef {
564 self.schema.clone()
565 }
566}
567
568#[cfg(test)]
569mod tests {
570 use super::*;
571 use crate::physical_plan::create_index_schema;
572 use crate::physical_plan::Index;
573 use async_trait::async_trait;
574 use datafusion::arrow::array::StringArray;
575 use datafusion::arrow::array::UInt64Array;
576 use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
577 use datafusion::arrow::record_batch::RecordBatch;
578 use datafusion::common::Statistics;
579 use datafusion::logical_expr::Expr;
580 use datafusion::logical_expr::{col, lit};
581 use datafusion::physical_plan::joins::HashJoinExec;
582 use datafusion::physical_plan::memory::MemoryStream;
583 use datafusion::prelude::SessionContext;
584 use std::any::Any;
585 use std::sync::Mutex;
586 use std::time::Duration;
587
    /// Name of the key column shared by the mock index and the mock fetchers.
    const PK_COL: &str = "id";
589
    /// Test index whose `scan` replays a fixed list of batches.
    #[derive(Debug)]
    struct MockIndex {
        /// Schema returned by `index_schema` and used for the replayed stream.
        schema: SchemaRef,
        /// Set to `true` the first time `scan` is called.
        scan_called: Mutex<bool>,
        /// Batches handed back by every `scan`.
        batches: Vec<RecordBatch>,
    }
597
598 impl MockIndex {
599 fn new(batches: Vec<RecordBatch>) -> Self {
600 Self {
601 schema: create_index_schema([Field::new(PK_COL, DataType::UInt64, false)]),
602 scan_called: Mutex::new(false),
603 batches,
604 }
605 }
606 }
607
608 impl Index for MockIndex {
609 fn as_any(&self) -> &dyn Any {
610 self
611 }
612
613 fn name(&self) -> &str {
614 "mock_index"
615 }
616
617 fn index_schema(&self) -> SchemaRef {
618 self.schema.clone()
619 }
620
621 fn table_name(&self) -> &str {
622 "mock_table"
623 }
624
625 fn column_name(&self) -> &str {
626 "mock_column"
627 }
628
629 fn scan(
630 &self,
631 _filters: &[Expr],
632 _limit: Option<usize>,
633 ) -> Result<SendableRecordBatchStream> {
634 *self.scan_called.lock().unwrap() = true;
635 let stream = MemoryStream::try_new(self.batches.clone(), self.schema.clone(), None)?;
636 Ok(Box::pin(stream))
637 }
638
639 fn statistics(&self) -> Statistics {
640 Statistics::new_unknown(&self.schema)
641 }
642 }
643
    /// Fetcher used by plan-construction tests; its `fetch` is never expected to run.
    #[derive(Debug, Clone)]
    struct MockRecordFetcher {
        /// Schema reported by `RecordFetcher::schema`.
        schema: SchemaRef,
    }
649
650 impl MockRecordFetcher {
651 fn new() -> Self {
652 Self {
653 schema: Arc::new(Schema::new(vec![
654 Field::new(PK_COL, DataType::UInt64, false),
655 Field::new("name", DataType::Utf8, false),
656 ])),
657 }
658 }
659
660 fn with_data(self) -> impl RecordFetcher {
661 #[derive(Debug)]
662 struct MockFetcherWithData {
663 schema: SchemaRef,
664 }
665
666 #[async_trait]
667 impl RecordFetcher for MockFetcherWithData {
668 fn schema(&self) -> SchemaRef {
669 self.schema.clone()
670 }
671
672 async fn fetch(&self, index_batch: RecordBatch) -> Result<RecordBatch> {
673 let row_ids = index_batch
674 .column_by_name(PK_COL)
675 .unwrap()
676 .as_any()
677 .downcast_ref::<UInt64Array>()
678 .unwrap();
679
680 let names: Vec<_> = row_ids
681 .values()
682 .iter()
683 .map(|id| format!("name_{id}"))
684 .collect();
685
686 Ok(RecordBatch::try_new(
687 self.schema.clone(),
688 vec![
689 Arc::new(row_ids.clone()),
690 Arc::new(StringArray::from(names)),
691 ],
692 )?)
693 }
694 }
695
696 MockFetcherWithData {
697 schema: self.schema,
698 }
699 }
700 }
701
    #[async_trait]
    impl RecordFetcher for MockRecordFetcher {
        /// Schema created in `MockRecordFetcher::new`.
        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        /// Panics when called: tests using this mock only inspect the plan.
        async fn fetch(&self, _index_batch: RecordBatch) -> Result<RecordBatch> {
            unimplemented!("MockRecordFetcher::fetch should not be called in these tests")
        }
    }
712
    /// Fetcher that sleeps before answering, so its future returns `Pending`.
    #[derive(Debug)]
    struct SlowRecordFetcher {
        /// Schema reported by `RecordFetcher::schema`.
        schema: SchemaRef,
        /// Lookup table indexed by key value.
        names: Vec<String>,
    }
719
720 impl SlowRecordFetcher {
721 fn new(names: Vec<String>) -> Self {
722 Self {
723 schema: Arc::new(Schema::new(vec![
724 Field::new(PK_COL, DataType::UInt64, false),
725 Field::new("name", DataType::Utf8, false),
726 ])),
727 names,
728 }
729 }
730 }
731
732 #[async_trait]
733 impl RecordFetcher for SlowRecordFetcher {
734 fn schema(&self) -> SchemaRef {
735 self.schema.clone()
736 }
737
738 async fn fetch(&self, index_batch: RecordBatch) -> Result<RecordBatch> {
739 tokio::time::sleep(Duration::from_millis(20)).await;
741
742 let row_ids = index_batch
743 .column_by_name(PK_COL)
744 .unwrap()
745 .as_any()
746 .downcast_ref::<UInt64Array>()
747 .unwrap();
748
749 let mut names = Vec::with_capacity(row_ids.len());
751 for id in row_ids.values().iter() {
752 tokio::time::sleep(Duration::from_millis(20)).await;
754 names.push(self.names[*id as usize].clone());
755 }
756
757 Ok(RecordBatch::try_new(
758 self.schema.clone(),
759 vec![
760 Arc::new(row_ids.clone()),
761 Arc::new(StringArray::from(names)),
762 ],
763 )?)
764 }
765 }
766
767 #[tokio::test]
768 async fn test_record_fetch_exec_slow_input() {
769 let session_ctx = SessionContext::new();
770 let _task_ctx = session_ctx.task_ctx();
771 let schema = Arc::new(Schema::new(vec![Field::new(
772 PK_COL,
773 DataType::UInt64,
774 false,
775 )]));
776
777 let input_stream = MemoryStream::try_new(
779 vec![RecordBatch::try_new(
780 schema.clone(),
781 vec![Arc::new(UInt64Array::from(vec![0, 1, 2, 3, 4]))],
782 )
783 .expect("Failed to create RecordBatch")],
784 schema.clone(),
785 None,
786 )
787 .expect("Failed to create MemoryStream");
788
789 let fetcher = Arc::new(SlowRecordFetcher::new(vec![
790 "name_0".to_string(),
791 "name_1".to_string(),
792 "name_2".to_string(),
793 "name_3".to_string(),
794 "name_4".to_string(),
795 ]));
796 let metrics = ExecutionPlanMetricsSet::new();
797 let baseline_metrics = BaselineMetrics::new(&metrics, 0);
798
799 let mut stream = RecordFetchStream::new(Box::pin(input_stream), fetcher, baseline_metrics);
800
801 let mut total_rows = 0;
802 while let Some(batch_result) = stream.next().await {
803 let batch = batch_result.unwrap();
804 total_rows += batch.num_rows();
805 }
806
807 assert_eq!(total_rows, 5, "Should have fetched all 5 rows");
808 }
809
810 #[tokio::test]
811 async fn test_record_fetch_exec_slow_and_multiple() {
812 let session_ctx = SessionContext::new();
813 let _task_ctx = session_ctx.task_ctx();
814 let schema = Arc::new(Schema::new(vec![Field::new(
815 PK_COL,
816 DataType::UInt64,
817 false,
818 )]));
819
820 let input_stream = MemoryStream::try_new(
822 vec![
823 RecordBatch::try_new(
824 schema.clone(),
825 vec![Arc::new(UInt64Array::from(vec![0, 1, 2]))],
826 )
827 .expect("Failed to create RecordBatch"),
828 RecordBatch::try_new(
829 schema.clone(),
830 vec![Arc::new(UInt64Array::from(vec![3, 4]))],
831 )
832 .expect("Failed to create RecordBatch"),
833 ],
834 schema.clone(),
835 None,
836 )
837 .expect("Failed to create MemoryStream");
838
839 let fetcher = Arc::new(SlowRecordFetcher::new(vec![
840 "name_0".to_string(),
841 "name_1".to_string(),
842 "name_2".to_string(),
843 "name_3".to_string(),
844 "name_4".to_string(),
845 ]));
846 let metrics = ExecutionPlanMetricsSet::new();
847 let baseline_metrics = BaselineMetrics::new(&metrics, 0);
848
849 let mut stream = RecordFetchStream::new(Box::pin(input_stream), fetcher, baseline_metrics);
850
851 let mut total_rows = 0;
852 while let Some(batch_result) = stream.next().await {
853 let batch = batch_result.unwrap();
854 total_rows += batch.num_rows();
855 }
856
857 assert_eq!(total_rows, 5, "Should have fetched all 5 rows");
858 }
859
860 #[tokio::test]
861 async fn test_record_fetch_exec_multiple_recordbatch() {
862 let session_ctx = SessionContext::new();
863 let _task_ctx = session_ctx.task_ctx();
864 let schema = Arc::new(Schema::new(vec![Field::new(
865 PK_COL,
866 DataType::UInt64,
867 false,
868 )]));
869
870 let input_stream = MemoryStream::try_new(
872 vec![
873 RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![0]))])
874 .expect("Failed to create RecordBatch"),
875 RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![1]))])
876 .expect("Failed to create RecordBatch"),
877 RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![2]))])
878 .expect("Failed to create RecordBatch"),
879 RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![3]))])
880 .expect("Failed to create RecordBatch"),
881 RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![4]))])
882 .expect("Failed to create RecordBatch"),
883 ],
884 schema.clone(),
885 None,
886 )
887 .expect("Failed to create MemoryStream");
888
889 let fetcher = Arc::new(SlowRecordFetcher::new(vec![
890 "name_0".to_string(),
891 "name_1".to_string(),
892 "name_2".to_string(),
893 "name_3".to_string(),
894 "name_4".to_string(),
895 ]));
896 let metrics = ExecutionPlanMetricsSet::new();
897 let baseline_metrics = BaselineMetrics::new(&metrics, 0);
898
899 let mut stream = RecordFetchStream::new(Box::pin(input_stream), fetcher, baseline_metrics);
900
901 let mut total_rows = 0;
902 while let Some(batch_result) = stream.next().await {
903 let batch = batch_result.unwrap();
904 total_rows += batch.num_rows();
905 }
906
907 assert_eq!(total_rows, 5, "Should have fetched all 5 rows");
908 }
909
910 #[tokio::test]
913 async fn test_record_fetch_stream_eager_with_empty_batches() -> Result<()> {
914 let schema = Arc::new(Schema::new(vec![Field::new(
919 PK_COL,
920 DataType::UInt64,
921 false,
922 )]));
923 let batch1 = RecordBatch::try_new(
924 schema.clone(),
925 vec![Arc::new(UInt64Array::from(vec![1, 2]))],
926 )?;
927 let empty_batch = RecordBatch::new_empty(schema.clone());
928 let batch2 = RecordBatch::try_new(
929 schema.clone(),
930 vec![Arc::new(UInt64Array::from(vec![3, 4]))],
931 )?;
932 let input_stream = MemoryStream::try_new(vec![batch1, empty_batch, batch2], schema, None)?;
933
934 let names = (0..5).map(|i| format!("name_{i}")).collect();
936 let fetcher = Arc::new(SlowRecordFetcher::new(names));
937 let metrics = ExecutionPlanMetricsSet::new();
938 let baseline_metrics = BaselineMetrics::new(&metrics, 0);
939 let stream =
940 RecordFetchStream::new(Box::pin(input_stream), fetcher.clone(), baseline_metrics);
941
942 let results = datafusion::physical_plan::common::collect(Box::pin(stream)).await?;
944
945 let expected_batch1 = RecordBatch::try_new(
947 fetcher.schema(),
948 vec![
949 Arc::new(UInt64Array::from(vec![1, 2])),
950 Arc::new(StringArray::from(vec!["name_1", "name_2"])),
951 ],
952 )?;
953 let expected_batch2 = RecordBatch::try_new(
954 fetcher.schema(),
955 vec![
956 Arc::new(UInt64Array::from(vec![3, 4])),
957 Arc::new(StringArray::from(vec!["name_3", "name_4"])),
958 ],
959 )?;
960
961 assert_eq!(
962 results.len(),
963 2,
964 "Should have produced two non-empty batches"
965 );
966 assert_eq!(results[0], expected_batch1);
967 assert_eq!(results[1], expected_batch2);
968
969 Ok(())
970 }
971
972 #[tokio::test]
973 async fn test_record_fetch_exec_no_indexes() {
974 let fetcher = Arc::new(MockRecordFetcher::new());
975 let err = RecordFetchExec::try_new(
976 vec![],
977 None,
978 fetcher,
979 Arc::new(Schema::empty()),
980 UnionMode::Parallel,
981 )
982 .unwrap_err();
983 assert!(
984 matches!(err, DataFusionError::Plan(ref msg) if msg == "RecordFetchExec requires at least one index"),
985 "Unexpected error: {err:?}"
986 );
987 }
988
989 #[tokio::test]
990 async fn test_record_fetch_exec_single_index() -> Result<()> {
991 let index_batch = RecordBatch::try_from_iter(vec![(
992 PK_COL,
993 Arc::new(UInt64Array::from(vec![1, 3])) as _,
994 )])?;
995 let index = Arc::new(MockIndex::new(vec![index_batch]));
996 let indexes: Vec<IndexFilter> = vec![IndexFilter::Single {
997 index: index.clone() as Arc<dyn Index>,
998 filter: col("a").eq(lit(1)),
999 }];
1000
1001 let fetcher = Arc::new(MockRecordFetcher::new());
1002 let exec = RecordFetchExec::try_new(
1003 indexes,
1004 None,
1005 fetcher,
1006 Arc::new(Schema::empty()),
1007 UnionMode::Parallel,
1008 )?;
1009
1010 assert_eq!(exec.input.name(), "IndexScanExec");
1012 Ok(())
1013 }
1014
1015 #[tokio::test]
1016 async fn test_record_fetch_exec_multiple_indexes() -> Result<()> {
1017 let index1_batch = RecordBatch::try_from_iter(vec![(
1019 PK_COL,
1020 Arc::new(UInt64Array::from(vec![1, 3])) as _,
1021 )])?;
1022 let index1 = Arc::new(MockIndex::new(vec![index1_batch]));
1023
1024 let index2_batch = RecordBatch::try_from_iter(vec![(
1025 PK_COL,
1026 Arc::new(UInt64Array::from(vec![3, 5])) as _,
1027 )])?;
1028 let index2 = Arc::new(MockIndex::new(vec![index2_batch]));
1029
1030 let indexes = vec![IndexFilter::And(vec![
1031 IndexFilter::Single {
1032 index: index1,
1033 filter: col("a").eq(lit(1)),
1034 },
1035 IndexFilter::Single {
1036 index: index2,
1037 filter: col("a").eq(lit(1)),
1038 },
1039 ])];
1040
1041 let fetcher = Arc::new(MockRecordFetcher::new());
1042 let exec = RecordFetchExec::try_new(
1043 indexes,
1044 None,
1045 fetcher,
1046 Arc::new(Schema::empty()),
1047 UnionMode::Parallel,
1048 )?;
1049
1050 assert_eq!(exec.input.name(), "ProjectionExec");
1052 let projection = exec
1053 .input
1054 .as_any()
1055 .downcast_ref::<ProjectionExec>()
1056 .unwrap();
1057 assert_eq!(projection.children()[0].name(), "HashJoinExec");
1058 Ok(())
1059 }
1060
1061 #[tokio::test]
1062 async fn test_record_fetch_exec_five_indexes() -> Result<()> {
1063 let mut indexes_vec = Vec::new();
1064 for i in 0..5 {
1065 let batch = RecordBatch::try_from_iter(vec![(
1066 PK_COL,
1067 Arc::new(UInt64Array::from(vec![i, i + 1, i + 2])) as _,
1068 )])?;
1069 indexes_vec.push(IndexFilter::Single {
1070 index: Arc::new(MockIndex::new(vec![batch])) as Arc<dyn Index>,
1071 filter: col("a").eq(lit(1)),
1072 });
1073 }
1074
1075 let indexes = vec![IndexFilter::And(indexes_vec)];
1076 let fetcher = Arc::new(MockRecordFetcher::new());
1077 let exec = RecordFetchExec::try_new(
1078 indexes,
1079 None,
1080 fetcher,
1081 Arc::new(Schema::empty()),
1082 UnionMode::Parallel,
1083 )?;
1084
1085 assert_eq!(exec.input.name(), "ProjectionExec");
1087
1088 fn count_joins(plan: &Arc<dyn ExecutionPlan>) -> usize {
1089 if let Some(join_exec) = plan.as_any().downcast_ref::<HashJoinExec>() {
1090 1 + count_joins(join_exec.children()[0]) + count_joins(join_exec.children()[1])
1091 } else {
1092 plan.children().iter().map(|c| count_joins(c)).sum()
1093 }
1094 }
1095
1096 let join_count = count_joins(&exec.input);
1097 assert_eq!(join_count, 4, "Expected 4 joins for 5 indexes");
1098
1099 Ok(())
1100 }
1101
1102 #[tokio::test]
1103 async fn test_record_fetch_exec_execute() -> Result<()> {
1104 let index_batch = RecordBatch::try_from_iter(vec![(
1106 PK_COL,
1107 Arc::new(UInt64Array::from(vec![1, 3, 5])) as _,
1108 )])?;
1109 let index = Arc::new(MockIndex::new(vec![index_batch]));
1110 let indexes = vec![IndexFilter::Single {
1111 index: index.clone() as Arc<dyn Index>,
1112 filter: col("a").eq(lit(1)),
1113 }];
1114
1115 let fetcher = Arc::new(MockRecordFetcher::new().with_data());
1116 let schema = fetcher.schema();
1117
1118 let exec =
1120 RecordFetchExec::try_new(indexes, None, fetcher, schema.clone(), UnionMode::Parallel)?;
1121
1122 let task_ctx = Arc::new(TaskContext::default());
1124 let mut stream = exec.execute(0, task_ctx)?;
1125 let mut results = Vec::new();
1126 while let Some(batch) = stream.next().await {
1127 results.push(batch?);
1128 }
1129
1130 let expected_names = vec!["name_1", "name_3", "name_5"];
1132 let expected_batch = RecordBatch::try_new(
1133 schema.clone(),
1134 vec![
1135 Arc::new(UInt64Array::from(vec![1, 3, 5])),
1136 Arc::new(StringArray::from(expected_names)),
1137 ],
1138 )?;
1139
1140 assert_eq!(results.len(), 1);
1141 assert_eq!(results[0], expected_batch);
1142
1143 Ok(())
1144 }
1145
1146 #[tokio::test]
1147 async fn test_record_fetch_exec_execute_empty_input() -> Result<()> {
1148 let index = Arc::new(MockIndex::new(vec![]));
1150 let indexes = vec![IndexFilter::Single {
1151 index: index.clone() as Arc<dyn Index>,
1152 filter: col("a").eq(lit(1)),
1153 }];
1154 let fetcher = Arc::new(MockRecordFetcher::new().with_data());
1155
1156 let exec = RecordFetchExec::try_new(
1158 indexes,
1159 None,
1160 fetcher,
1161 Arc::new(Schema::empty()),
1162 UnionMode::Parallel,
1163 )?;
1164
1165 let task_ctx = Arc::new(TaskContext::default());
1167 let mut stream = exec.execute(0, task_ctx)?;
1168 let mut results = Vec::new();
1169 while let Some(batch) = stream.next().await {
1170 results.push(batch?);
1171 }
1172
1173 assert!(results.is_empty());
1175
1176 Ok(())
1177 }
1178
1179 #[tokio::test]
1180 async fn test_record_fetch_exec_execute_multiple_batches() -> Result<()> {
1181 let batch1 = RecordBatch::try_from_iter(vec![(
1183 PK_COL,
1184 Arc::new(UInt64Array::from(vec![1, 3])) as _,
1185 )])?;
1186 let batch2 = RecordBatch::try_from_iter(vec![(
1187 PK_COL,
1188 Arc::new(UInt64Array::from(vec![5, 7])) as _,
1189 )])?;
1190 let index = Arc::new(MockIndex::new(vec![batch1, batch2]));
1191 let indexes = vec![IndexFilter::Single {
1192 index: index.clone() as Arc<dyn Index>,
1193 filter: col("a").eq(lit(1)),
1194 }];
1195 let fetcher = Arc::new(MockRecordFetcher::new().with_data());
1196 let schema = fetcher.schema();
1197
1198 let exec =
1200 RecordFetchExec::try_new(indexes, None, fetcher, schema.clone(), UnionMode::Parallel)?;
1201
1202 let task_ctx = Arc::new(TaskContext::default());
1204 let results =
1205 datafusion::physical_plan::common::collect(exec.execute(0, task_ctx)?).await?;
1206
1207 let expected_batch1 = RecordBatch::try_new(
1209 schema.clone(),
1210 vec![
1211 Arc::new(UInt64Array::from(vec![1, 3])),
1212 Arc::new(StringArray::from(vec!["name_1", "name_3"])),
1213 ],
1214 )?;
1215 let expected_batch2 = RecordBatch::try_new(
1216 schema.clone(),
1217 vec![
1218 Arc::new(UInt64Array::from(vec![5, 7])),
1219 Arc::new(StringArray::from(vec!["name_5", "name_7"])),
1220 ],
1221 )?;
1222
1223 assert_eq!(results.len(), 2);
1224 assert_eq!(results[0], expected_batch1);
1225 assert_eq!(results[1], expected_batch2);
1226
1227 Ok(())
1228 }
1229
1230 #[tokio::test]
1231 async fn test_record_fetch_exec_fetcher_error() -> Result<()> {
1232 #[derive(Debug)]
1234 struct ErrorFetcher;
1235 #[async_trait]
1236 impl RecordFetcher for ErrorFetcher {
1237 fn schema(&self) -> SchemaRef {
1238 Arc::new(Schema::empty())
1239 }
1240 async fn fetch(&self, _index_batch: RecordBatch) -> Result<RecordBatch> {
1241 Err(DataFusionError::Execution("fetcher error".to_string()))
1242 }
1243 }
1244
1245 let index_batch =
1246 RecordBatch::try_from_iter(vec![(PK_COL, Arc::new(UInt64Array::from(vec![1])) as _)])?;
1247 let index = Arc::new(MockIndex::new(vec![index_batch]));
1248 let indexes = vec![IndexFilter::Single {
1249 index: index.clone() as Arc<dyn Index>,
1250 filter: col("a").eq(lit(1)),
1251 }];
1252 let fetcher = Arc::new(ErrorFetcher);
1253
1254 let exec = RecordFetchExec::try_new(
1256 indexes,
1257 None,
1258 fetcher,
1259 Arc::new(Schema::empty()),
1260 UnionMode::Parallel,
1261 )?;
1262
1263 let task_ctx = Arc::new(TaskContext::default());
1265 let result = datafusion::physical_plan::common::collect(exec.execute(0, task_ctx)?).await;
1266
1267 assert!(result.is_err());
1268 assert!(matches!(result.unwrap_err(), DataFusionError::Execution(_)));
1269
1270 Ok(())
1271 }
1272}