Skip to main content

datafusion_index_provider/physical_plan/exec/
fetch.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::fmt;
20use std::sync::Arc;
21
22use datafusion::common::Statistics;
23use datafusion::error::{DataFusionError, Result};
24use datafusion::execution::context::TaskContext;
25use datafusion::execution::SendableRecordBatchStream;
26use futures::future::BoxFuture;
27use futures::FutureExt;
28
29use std::pin::Pin;
30use std::task::{Context, Poll};
31
32use datafusion::arrow::datatypes::{Schema, SchemaRef};
33use datafusion::arrow::record_batch::RecordBatch;
34use datafusion::physical_expr::EquivalenceProperties;
35use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
36use datafusion::physical_plan::{
37    DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PlanProperties,
38    RecordBatchStream,
39};
40use futures::stream::{Stream, StreamExt};
41
42use crate::physical_plan::exec::index::IndexScanExec;
43use crate::physical_plan::exec::sequential_union::SequentialUnionExec;
44use crate::physical_plan::fetcher::RecordFetcher;
45use crate::physical_plan::joins::try_create_index_lookup_join;
46use crate::types::{IndexFilter, IndexFilters, UnionMode};
47use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
48use datafusion::physical_plan::empty::EmptyExec;
49use datafusion::physical_plan::expressions::Column;
50use datafusion::physical_plan::projection::ProjectionExec;
51use datafusion::physical_plan::union::UnionExec;
52use datafusion::physical_plan::PhysicalExpr;
53
/// Physical plan node for fetching records from a [`RecordFetcher`] using
/// row IDs produced by one or more index scans.
///
/// This operator takes one or more [`IndexFilter`]s, builds an input plan
/// to produce row IDs (by scanning and joining index results), and then uses
/// a [`RecordFetcher`] to retrieve the actual data for those row IDs.
#[derive(Debug)]
pub struct RecordFetchExec {
    /// The index filters this node was constructed from; rendered in the
    /// plan's display output.
    indexes: Arc<IndexFilters>,
    /// Optional row limit supplied at construction time. It is handed to the
    /// input-plan builder and shown in the display output.
    limit: Option<usize>,
    /// Cached plan properties: a single output partition, with emission type
    /// and boundedness copied from the input plan at construction time.
    plan_properties: PlanProperties,
    /// Fetcher used at execution time to turn row-ID batches into records.
    record_fetcher: Arc<dyn RecordFetcher>,
    /// The input plan that produces the row IDs.
    input: Arc<dyn ExecutionPlan>,
    /// Metrics set from which per-partition baseline metrics are created.
    metrics: ExecutionPlanMetricsSet,
    /// Output schema reported by this node.
    schema: SchemaRef,
    /// Controls how union operations are executed for OR conditions.
    union_mode: UnionMode,
}
73
74impl RecordFetchExec {
75    /// Create a new `RecordFetchExec` plan.
76    ///
77    /// # Arguments
78    /// * `indexes` - Index filters to use for scanning
79    /// * `limit` - Optional limit on the number of rows
80    /// * `record_fetcher` - The fetcher to retrieve records by row ID
81    /// * `schema` - Output schema
82    /// * `union_mode` - Controls whether OR conditions use parallel or sequential union
83    pub fn try_new(
84        indexes: Vec<IndexFilter>,
85        limit: Option<usize>,
86        record_fetcher: Arc<dyn RecordFetcher>,
87        schema: SchemaRef,
88        union_mode: UnionMode,
89    ) -> Result<Self> {
90        if indexes.is_empty() {
91            return Err(DataFusionError::Plan(
92                "RecordFetchExec requires at least one index".to_string(),
93            ));
94        }
95
96        if indexes.len() > 1 {
97            return Err(DataFusionError::Internal(
98                "RecordFetchExec expects a single root IndexFilter".to_string(),
99            ));
100        }
101
102        let input = match indexes.first() {
103            Some(index_filter) => Self::build_scan_exec(index_filter, limit, union_mode)?,
104            None => {
105                return Err(DataFusionError::Plan(
106                    "RecordFetchExec requires at least one index".to_string(),
107                ));
108            }
109        };
110        let eq_properties = EquivalenceProperties::new(schema.clone());
111        let plan_properties = PlanProperties::new(
112            eq_properties,
113            Partitioning::UnknownPartitioning(1),
114            input.properties().emission_type,
115            input.properties().boundedness,
116        );
117
118        Ok(Self {
119            indexes: indexes.into(),
120            limit,
121            plan_properties,
122            record_fetcher,
123            input,
124            metrics: ExecutionPlanMetricsSet::new(),
125            schema,
126            union_mode,
127        })
128    }
129
130    /// Builds the input execution plan that produces primary key values based on the IndexFilter structure.
131    ///
132    /// This method is the core of the index-based execution plan generation. It recursively
133    /// processes the [`IndexFilter`] tree to create an optimized physical plan that efficiently
134    /// produces primary key values matching the query filters.
135    ///
136    /// # Plan Generation Strategy
137    ///
138    /// The method generates different execution plans based on the [`IndexFilter`] variant:
139    ///
140    /// ## [`IndexFilter::Single`] - Direct Index Scan
141    /// Creates a single [`IndexScanExec`] that directly scans the specified index with the given filter.
142    /// This is the most efficient case with minimal overhead.
143    ///
144    /// ```text
145    /// IndexScanExec(index, filter)
146    /// ```
147    ///
148    /// ## [`IndexFilter::And`] - Index Intersection via Joins
149    /// Builds a left-deep tree of joins to intersect primary key values from multiple indexes.
150    /// The joins are performed on all primary key columns to find rows that satisfy ALL conditions.
151    /// Uses [`crate::physical_plan::joins::try_create_index_lookup_join`] which selects between
152    /// HashJoin and SortMergeJoin based on input ordering.
153    ///
154    /// ```text
155    /// Projection(PK columns)
156    /// └── HashJoin/SortMergeJoin(
157    ///       Projection(PK columns)
158    ///       └── HashJoin/SortMergeJoin(
159    ///             IndexScanExec(index1, filter1),
160    ///             IndexScanExec(index2, filter2)
161    ///           ),
162    ///       IndexScanExec(index3, filter3)
163    ///     )
164    /// ```
165    ///
166    /// ## [`IndexFilter::Or`] - Union with Deduplication
167    /// Creates a [`UnionExec`](datafusion::physical_plan::union::UnionExec) of all index scans followed by an [`AggregateExec`] that groups by
168    /// all primary key columns to automatically deduplicate overlapping results.
169    ///
170    /// ```text
171    /// AggregateExec(GROUP BY PK columns,
172    ///   UnionExec(
173    ///     IndexScanExec(index1, filter1),
174    ///     IndexScanExec(index2, filter2),
175    ///     IndexScanExec(index3, filter3)
176    ///   )
177    /// )
178    /// ```
179    ///
180    /// # Arguments
181    /// * `index_filter` - The [`IndexFilter`] tree specifying which indexes to scan and how to combine them
182    /// * `limit` - Optional limit on the number of rows to return, passed through to individual index scans
183    /// * `union_mode` - Controls whether OR conditions use parallel or sequential union
184    ///
185    /// # Returns
186    /// An [`Arc<dyn ExecutionPlan>`] that produces a stream of primary key values matching the filter criteria.
187    /// The output schema contains all columns from the index schema (the composite primary key).
188    ///
189    /// # Errors
190    /// Returns [`DataFusionError::Plan`] if:
191    /// - An [`IndexFilter::And`] contains no sub-filters
192    /// - Any recursive call to build sub-plans fails
193    /// - Index scan creation fails due to invalid filter expressions
194    fn build_scan_exec(
195        index_filter: &IndexFilter,
196        limit: Option<usize>,
197        union_mode: UnionMode,
198    ) -> Result<Arc<dyn ExecutionPlan>> {
199        match index_filter {
200            IndexFilter::Single { index, filter } => {
201                let schema = index.index_schema();
202                let exec =
203                    IndexScanExec::try_new(index.clone(), vec![filter.clone()], limit, schema)?;
204                Ok(Arc::new(exec))
205            }
206            IndexFilter::And(filters) => {
207                let mut plans = filters
208                    .iter()
209                    .map(|f| Self::build_scan_exec(f, limit, union_mode))
210                    .collect::<Result<Vec<_>>>()?;
211
212                if plans.is_empty() {
213                    return Err(DataFusionError::Plan(
214                        "IndexFilter::And requires at least one sub-filter".to_string(),
215                    ));
216                }
217
218                let mut left = plans.remove(0);
219                let pk_schema = left.schema();
220                while !plans.is_empty() {
221                    let right = plans.remove(0);
222                    let joined = try_create_index_lookup_join(left, right)?;
223                    left = Self::project_to_pk_schema(joined, &pk_schema)?;
224                }
225                Ok(left)
226            }
227            IndexFilter::Or(filters) => {
228                let original_plans = filters
229                    .iter()
230                    .map(|f| Self::build_scan_exec(f, limit, union_mode))
231                    .collect::<Result<Vec<_>>>()?;
232
233                if original_plans.is_empty() {
234                    return Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))));
235                }
236
237                // Derive canonical PK schema from the first plan
238                let canonical_schema = original_plans[0].schema();
239
240                // Normalize all plans to the canonical schema
241                let normalized_plans: Vec<Arc<dyn ExecutionPlan>> = original_plans
242                    .into_iter()
243                    .map(|plan| Self::project_to_pk_schema(plan, &canonical_schema))
244                    .collect::<Result<Vec<_>>>()?;
245
246                // Create union based on mode
247                let union_input: Arc<dyn ExecutionPlan> = match union_mode {
248                    UnionMode::Parallel => UnionExec::try_new(normalized_plans)?,
249                    UnionMode::Sequential => {
250                        Arc::new(SequentialUnionExec::try_new(normalized_plans)?)
251                    }
252                };
253
254                // Create aggregate to deduplicate by ALL primary key columns
255                let group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = canonical_schema
256                    .fields()
257                    .iter()
258                    .enumerate()
259                    .map(|(i, field)| {
260                        (
261                            Arc::new(Column::new(field.name(), i)) as Arc<dyn PhysicalExpr>,
262                            field.name().to_string(),
263                        )
264                    })
265                    .collect();
266
267                let group_by = PhysicalGroupBy::new_single(group_exprs);
268
269                let agg_exec = AggregateExec::try_new(
270                    AggregateMode::Single,
271                    group_by,
272                    vec![],
273                    vec![],
274                    union_input,
275                    canonical_schema,
276                )?;
277
278                Ok(Arc::new(agg_exec))
279            }
280        }
281    }
282
283    /// Projects a plan's output down to the primary key schema columns.
284    ///
285    /// After a join, the output may contain duplicate columns from both sides
286    /// (e.g., `(a_left, b_left, a_right, b_right)`). This projects to just the
287    /// first occurrence of each PK column.
288    fn project_to_pk_schema(
289        plan: Arc<dyn ExecutionPlan>,
290        pk_schema: &SchemaRef,
291    ) -> Result<Arc<dyn ExecutionPlan>> {
292        let plan_schema = plan.schema();
293
294        // Short-circuit if schemas already match
295        if plan_schema.fields().len() == pk_schema.fields().len()
296            && pk_schema
297                .fields()
298                .iter()
299                .enumerate()
300                .all(|(i, f)| plan_schema.field(i) == f.as_ref())
301        {
302            return Ok(plan);
303        }
304
305        // Build projection selecting first occurrence of each PK column
306        let exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = pk_schema
307            .fields()
308            .iter()
309            .map(|field| {
310                let idx = plan_schema
311                    .fields()
312                    .iter()
313                    .position(|f| f.name() == field.name())
314                    .ok_or_else(|| {
315                        DataFusionError::Plan(format!(
316                            "Primary key column '{}' not found in plan schema: {:?}",
317                            field.name(),
318                            plan_schema
319                        ))
320                    })?;
321                Ok((
322                    Arc::new(Column::new(field.name(), idx)) as Arc<dyn PhysicalExpr>,
323                    field.name().to_string(),
324                ))
325            })
326            .collect::<Result<Vec<_>>>()?;
327
328        Ok(Arc::new(ProjectionExec::try_new(exprs, plan)?))
329    }
330}
331
332impl DisplayAs for RecordFetchExec {
333    fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
334        match t {
335            DisplayFormatType::Default
336            | DisplayFormatType::Verbose
337            | DisplayFormatType::TreeRender => {
338                let index_names: Vec<_> = self.indexes.iter().map(|i| i.to_string()).collect();
339                write!(
340                    f,
341                    "RecordFetchExec: indexes=[{}], limit={:?}",
342                    index_names.join(", "),
343                    self.limit
344                )
345            }
346        }
347    }
348}
349
350impl ExecutionPlan for RecordFetchExec {
351    /// Return a reference to the name of this execution plan.
352    fn name(&self) -> &str {
353        "RecordFetchExec"
354    }
355
356    /// Return a reference to the logical plan as [`Any`] so that it can be
357    /// downcast to a specific implementation.
358    fn as_any(&self) -> &dyn Any {
359        self
360    }
361
362    /// Get the schema of this execution plan
363    fn schema(&self) -> SchemaRef {
364        self.schema.clone()
365    }
366
367    /// Get the properties for this execution plan
368    fn properties(&self) -> &PlanProperties {
369        &self.plan_properties
370    }
371
372    /// Returns the children of this [`ExecutionPlan`].
373    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
374        vec![&self.input]
375    }
376
377    fn required_input_distribution(&self) -> Vec<Distribution> {
378        // RecordFetchExec requires a single partition input because it merges
379        // results from multiple index scans.
380        vec![Distribution::SinglePartition]
381    }
382
383    /// Create a new [`ExecutionPlan`] with new children.
384    fn with_new_children(
385        self: Arc<Self>,
386        children: Vec<Arc<dyn ExecutionPlan>>,
387    ) -> Result<Arc<dyn ExecutionPlan>> {
388        if children.len() != 1 {
389            return Err(DataFusionError::Internal(
390                "RecordFetchExec should have exactly one child".to_string(),
391            ));
392        }
393        Ok(Arc::new(RecordFetchExec {
394            indexes: self.indexes.clone(),
395            limit: self.limit,
396            plan_properties: self.plan_properties.clone(),
397            record_fetcher: self.record_fetcher.clone(),
398            input: children[0].clone(),
399            metrics: self.metrics.clone(),
400            schema: self.schema.clone(),
401            union_mode: self.union_mode,
402        }))
403    }
404
405    /// Executes this plan and returns a stream of `RecordBatch`es.
406    fn execute(
407        &self,
408        partition: usize,
409        context: Arc<TaskContext>,
410    ) -> Result<SendableRecordBatchStream> {
411        if partition != 0 {
412            return Err(DataFusionError::Internal(format!(
413                "RecordFetchExec executed with partition {partition} but expected 0"
414            )));
415        }
416
417        let input_stream = self.input.execute(0, context)?;
418        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
419
420        Ok(Box::pin(RecordFetchStream::new(
421            input_stream,
422            self.record_fetcher.clone(),
423            baseline_metrics,
424        )))
425    }
426
427    /// Get the statistics for this execution plan.
428    fn statistics(&self) -> Result<Statistics> {
429        Ok(Statistics::new_unknown(&self.schema()))
430    }
431}
432
/// A stream that fetches records using row IDs from an input stream.
///
/// Each non-empty batch read from the input is handed to a
/// [`RecordFetcher`]; non-empty fetch results are yielded downstream.
pub struct RecordFetchStream {
    /// The schema of the output data, taken from the fetcher at construction.
    schema: SchemaRef,
    /// Execution metrics; every poll result is recorded through it.
    baseline_metrics: BaselineMetrics,
    /// The state of the stream.
    state: FetchState,
}
442
/// A future that resolves to a fetched `RecordBatch` and the reclaimed
/// input stream and fetcher.
///
/// On success the tuple is `(input stream, fetcher, fetched batch)`; the
/// first two are handed back so the stream can resume reading input.
type FetchFuture = BoxFuture<
    'static,
    Result<(
        SendableRecordBatchStream,
        Arc<dyn RecordFetcher>,
        RecordBatch,
    )>,
>;
453
/// The state of the `RecordFetchStream`.
enum FetchState {
    /// Reading from the input stream.
    ReadingInput {
        /// Stream of row-ID batches still to be consumed.
        input: SendableRecordBatchStream,
        /// Fetcher that will resolve the next non-empty batch.
        fetcher: Arc<dyn RecordFetcher>,
    },
    /// Fetching a batch of records. The future returns the input stream and
    /// fetcher so they can be reclaimed.
    Fetching(FetchFuture),
    /// An error occurred.
    ///
    /// This variant is also what `poll_next` leaves in place while it works
    /// on the previous state; polling the stream in this state yields `None`.
    Error,
}
467
468impl RecordFetchStream {
469    /// Create a new `RecordFetchStream`.
470    pub fn new(
471        input: SendableRecordBatchStream,
472        fetcher: Arc<dyn RecordFetcher>,
473        baseline_metrics: BaselineMetrics,
474    ) -> Self {
475        let schema = fetcher.schema();
476        let state = FetchState::ReadingInput { input, fetcher };
477        Self {
478            schema,
479            baseline_metrics,
480            state,
481        }
482    }
483}
484
impl Stream for RecordFetchStream {
    type Item = Result<RecordBatch>;

    /// Drives the read-then-fetch state machine until a batch, an error,
    /// end-of-stream, or `Pending` can be reported.
    ///
    /// The current state is moved out on every iteration and replaced by
    /// [`FetchState::Error`]. Arms that want the stream to keep going must
    /// write a new state back before looping or returning `Pending`. Arms
    /// that return an error or end-of-stream leave the placeholder in place,
    /// so later polls return `None`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Take ownership of the state; `Error` is a temporary stand-in.
            match std::mem::replace(&mut self.state, FetchState::Error) {
                FetchState::ReadingInput { mut input, fetcher } => {
                    match input.poll_next_unpin(cx) {
                        Poll::Ready(Some(Ok(batch))) if batch.num_rows() > 0 => {
                            // Start async fetch for non-empty batch
                            let fut = {
                                // The future owns the input stream and a fetcher
                                // handle, and returns both alongside the result.
                                let fetcher = fetcher.clone();
                                async move {
                                    fetcher
                                        .fetch(batch)
                                        .await
                                        .map(|batch| (input, fetcher, batch))
                                }
                                .boxed()
                            };
                            self.state = FetchState::Fetching(fut);
                        }
                        Poll::Ready(Some(Ok(_))) => {
                            // Empty batch - continue reading
                            self.state = FetchState::ReadingInput { input, fetcher };
                        }
                        Poll::Ready(Some(Err(e))) => {
                            // State stays `Error`: the stream ends after this item.
                            return self.baseline_metrics.record_poll(Poll::Ready(Some(Err(e))));
                        }
                        Poll::Ready(None) => {
                            // Input exhausted; state stays `Error`, so further
                            // polls also report end-of-stream.
                            return self.baseline_metrics.record_poll(Poll::Ready(None));
                        }
                        Poll::Pending => {
                            // Put the state back so the next poll resumes reading.
                            self.state = FetchState::ReadingInput { input, fetcher };
                            return self.baseline_metrics.record_poll(Poll::Pending);
                        }
                    }
                }
                FetchState::Fetching(mut fut) => {
                    match fut.as_mut().poll(cx) {
                        Poll::Ready(Ok((input, fetcher, batch))) if batch.num_rows() > 0 => {
                            // Yield non-empty batch and prepare for next input
                            self.state = FetchState::ReadingInput { input, fetcher };
                            return self
                                .baseline_metrics
                                .record_poll(Poll::Ready(Some(Ok(batch))));
                        }
                        Poll::Ready(Ok((input, fetcher, _))) => {
                            // Empty batch - continue reading
                            self.state = FetchState::ReadingInput { input, fetcher };
                        }
                        Poll::Ready(Err(e)) => {
                            // The future (and the input it owned) is dropped;
                            // state stays `Error`.
                            return self.baseline_metrics.record_poll(Poll::Ready(Some(Err(e))));
                        }
                        Poll::Pending => {
                            // Keep the in-flight fetch for the next poll.
                            self.state = FetchState::Fetching(fut);
                            return self.baseline_metrics.record_poll(Poll::Pending);
                        }
                    }
                }
                FetchState::Error => {
                    // Terminal: nothing more to yield.
                    return self.baseline_metrics.record_poll(Poll::Ready(None));
                }
            }
        }
    }
}
552
553impl fmt::Debug for RecordFetchStream {
554    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
555        f.debug_struct("RecordFetchStream")
556            .field("schema", &self.schema)
557            .field("baseline_metrics", &self.baseline_metrics)
558            .finish()
559    }
560}
561
562impl RecordBatchStream for RecordFetchStream {
563    fn schema(&self) -> SchemaRef {
564        self.schema.clone()
565    }
566}
567
568#[cfg(test)]
569mod tests {
570    use super::*;
571    use crate::physical_plan::create_index_schema;
572    use crate::physical_plan::Index;
573    use async_trait::async_trait;
574    use datafusion::arrow::array::StringArray;
575    use datafusion::arrow::array::UInt64Array;
576    use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
577    use datafusion::arrow::record_batch::RecordBatch;
578    use datafusion::common::Statistics;
579    use datafusion::logical_expr::Expr;
580    use datafusion::logical_expr::{col, lit};
581    use datafusion::physical_plan::joins::HashJoinExec;
582    use datafusion::physical_plan::memory::MemoryStream;
583    use datafusion::prelude::SessionContext;
584    use std::any::Any;
585    use std::sync::Mutex;
586    use std::time::Duration;
587
588    const PK_COL: &str = "id";
589
590    // --- Mock Index ---
591    #[derive(Debug)]
592    struct MockIndex {
593        schema: SchemaRef,
594        scan_called: Mutex<bool>,
595        batches: Vec<RecordBatch>,
596    }
597
598    impl MockIndex {
599        fn new(batches: Vec<RecordBatch>) -> Self {
600            Self {
601                schema: create_index_schema([Field::new(PK_COL, DataType::UInt64, false)]),
602                scan_called: Mutex::new(false),
603                batches,
604            }
605        }
606    }
607
608    impl Index for MockIndex {
609        fn as_any(&self) -> &dyn Any {
610            self
611        }
612
613        fn name(&self) -> &str {
614            "mock_index"
615        }
616
617        fn index_schema(&self) -> SchemaRef {
618            self.schema.clone()
619        }
620
621        fn table_name(&self) -> &str {
622            "mock_table"
623        }
624
625        fn column_name(&self) -> &str {
626            "mock_column"
627        }
628
629        fn scan(
630            &self,
631            _filters: &[Expr],
632            _limit: Option<usize>,
633        ) -> Result<SendableRecordBatchStream> {
634            *self.scan_called.lock().unwrap() = true;
635            let stream = MemoryStream::try_new(self.batches.clone(), self.schema.clone(), None)?;
636            Ok(Box::pin(stream))
637        }
638
639        fn statistics(&self) -> Statistics {
640            Statistics::new_unknown(&self.schema)
641        }
642    }
643
644    // --- Mock Record Fetcher ---
645    #[derive(Debug, Clone)]
646    struct MockRecordFetcher {
647        schema: SchemaRef,
648    }
649
650    impl MockRecordFetcher {
651        fn new() -> Self {
652            Self {
653                schema: Arc::new(Schema::new(vec![
654                    Field::new(PK_COL, DataType::UInt64, false),
655                    Field::new("name", DataType::Utf8, false),
656                ])),
657            }
658        }
659
660        fn with_data(self) -> impl RecordFetcher {
661            #[derive(Debug)]
662            struct MockFetcherWithData {
663                schema: SchemaRef,
664            }
665
666            #[async_trait]
667            impl RecordFetcher for MockFetcherWithData {
668                fn schema(&self) -> SchemaRef {
669                    self.schema.clone()
670                }
671
672                async fn fetch(&self, index_batch: RecordBatch) -> Result<RecordBatch> {
673                    let row_ids = index_batch
674                        .column_by_name(PK_COL)
675                        .unwrap()
676                        .as_any()
677                        .downcast_ref::<UInt64Array>()
678                        .unwrap();
679
680                    let names: Vec<_> = row_ids
681                        .values()
682                        .iter()
683                        .map(|id| format!("name_{id}"))
684                        .collect();
685
686                    Ok(RecordBatch::try_new(
687                        self.schema.clone(),
688                        vec![
689                            Arc::new(row_ids.clone()),
690                            Arc::new(StringArray::from(names)),
691                        ],
692                    )?)
693                }
694            }
695
696            MockFetcherWithData {
697                schema: self.schema,
698            }
699        }
700    }
701
702    #[async_trait]
703    impl RecordFetcher for MockRecordFetcher {
704        fn schema(&self) -> SchemaRef {
705            self.schema.clone()
706        }
707
708        async fn fetch(&self, _index_batch: RecordBatch) -> Result<RecordBatch> {
709            unimplemented!("MockRecordFetcher::fetch should not be called in these tests")
710        }
711    }
712
713    // --- Slow Record Fetcher ---
714    #[derive(Debug)]
715    struct SlowRecordFetcher {
716        schema: SchemaRef,
717        names: Vec<String>,
718    }
719
720    impl SlowRecordFetcher {
721        fn new(names: Vec<String>) -> Self {
722            Self {
723                schema: Arc::new(Schema::new(vec![
724                    Field::new(PK_COL, DataType::UInt64, false),
725                    Field::new("name", DataType::Utf8, false),
726                ])),
727                names,
728            }
729        }
730    }
731
732    #[async_trait]
733    impl RecordFetcher for SlowRecordFetcher {
734        fn schema(&self) -> SchemaRef {
735            self.schema.clone()
736        }
737
738        async fn fetch(&self, index_batch: RecordBatch) -> Result<RecordBatch> {
739            // Simulate a delay
740            tokio::time::sleep(Duration::from_millis(20)).await;
741
742            let row_ids = index_batch
743                .column_by_name(PK_COL)
744                .unwrap()
745                .as_any()
746                .downcast_ref::<UInt64Array>()
747                .unwrap();
748
749            // add a delay between each row
750            let mut names = Vec::with_capacity(row_ids.len());
751            for id in row_ids.values().iter() {
752                // simulate an await point
753                tokio::time::sleep(Duration::from_millis(20)).await;
754                names.push(self.names[*id as usize].clone());
755            }
756
757            Ok(RecordBatch::try_new(
758                self.schema.clone(),
759                vec![
760                    Arc::new(row_ids.clone()),
761                    Arc::new(StringArray::from(names)),
762                ],
763            )?)
764        }
765    }
766
767    #[tokio::test]
768    async fn test_record_fetch_exec_slow_input() {
769        let session_ctx = SessionContext::new();
770        let _task_ctx = session_ctx.task_ctx();
771        let schema = Arc::new(Schema::new(vec![Field::new(
772            PK_COL,
773            DataType::UInt64,
774            false,
775        )]));
776
777        // create a memoryStream of 5 rows
778        let input_stream = MemoryStream::try_new(
779            vec![RecordBatch::try_new(
780                schema.clone(),
781                vec![Arc::new(UInt64Array::from(vec![0, 1, 2, 3, 4]))],
782            )
783            .expect("Failed to create RecordBatch")],
784            schema.clone(),
785            None,
786        )
787        .expect("Failed to create MemoryStream");
788
789        let fetcher = Arc::new(SlowRecordFetcher::new(vec![
790            "name_0".to_string(),
791            "name_1".to_string(),
792            "name_2".to_string(),
793            "name_3".to_string(),
794            "name_4".to_string(),
795        ]));
796        let metrics = ExecutionPlanMetricsSet::new();
797        let baseline_metrics = BaselineMetrics::new(&metrics, 0);
798
799        let mut stream = RecordFetchStream::new(Box::pin(input_stream), fetcher, baseline_metrics);
800
801        let mut total_rows = 0;
802        while let Some(batch_result) = stream.next().await {
803            let batch = batch_result.unwrap();
804            total_rows += batch.num_rows();
805        }
806
807        assert_eq!(total_rows, 5, "Should have fetched all 5 rows");
808    }
809
810    #[tokio::test]
811    async fn test_record_fetch_exec_slow_and_multiple() {
812        let session_ctx = SessionContext::new();
813        let _task_ctx = session_ctx.task_ctx();
814        let schema = Arc::new(Schema::new(vec![Field::new(
815            PK_COL,
816            DataType::UInt64,
817            false,
818        )]));
819
820        // create a memoryStream of 5 rows
821        let input_stream = MemoryStream::try_new(
822            vec![
823                RecordBatch::try_new(
824                    schema.clone(),
825                    vec![Arc::new(UInt64Array::from(vec![0, 1, 2]))],
826                )
827                .expect("Failed to create RecordBatch"),
828                RecordBatch::try_new(
829                    schema.clone(),
830                    vec![Arc::new(UInt64Array::from(vec![3, 4]))],
831                )
832                .expect("Failed to create RecordBatch"),
833            ],
834            schema.clone(),
835            None,
836        )
837        .expect("Failed to create MemoryStream");
838
839        let fetcher = Arc::new(SlowRecordFetcher::new(vec![
840            "name_0".to_string(),
841            "name_1".to_string(),
842            "name_2".to_string(),
843            "name_3".to_string(),
844            "name_4".to_string(),
845        ]));
846        let metrics = ExecutionPlanMetricsSet::new();
847        let baseline_metrics = BaselineMetrics::new(&metrics, 0);
848
849        let mut stream = RecordFetchStream::new(Box::pin(input_stream), fetcher, baseline_metrics);
850
851        let mut total_rows = 0;
852        while let Some(batch_result) = stream.next().await {
853            let batch = batch_result.unwrap();
854            total_rows += batch.num_rows();
855        }
856
857        assert_eq!(total_rows, 5, "Should have fetched all 5 rows");
858    }
859
860    #[tokio::test]
861    async fn test_record_fetch_exec_multiple_recordbatch() {
862        let session_ctx = SessionContext::new();
863        let _task_ctx = session_ctx.task_ctx();
864        let schema = Arc::new(Schema::new(vec![Field::new(
865            PK_COL,
866            DataType::UInt64,
867            false,
868        )]));
869
870        // create a memoryStream of 5 recordBatch
871        let input_stream = MemoryStream::try_new(
872            vec![
873                RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![0]))])
874                    .expect("Failed to create RecordBatch"),
875                RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![1]))])
876                    .expect("Failed to create RecordBatch"),
877                RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![2]))])
878                    .expect("Failed to create RecordBatch"),
879                RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![3]))])
880                    .expect("Failed to create RecordBatch"),
881                RecordBatch::try_new(schema.clone(), vec![Arc::new(UInt64Array::from(vec![4]))])
882                    .expect("Failed to create RecordBatch"),
883            ],
884            schema.clone(),
885            None,
886        )
887        .expect("Failed to create MemoryStream");
888
889        let fetcher = Arc::new(SlowRecordFetcher::new(vec![
890            "name_0".to_string(),
891            "name_1".to_string(),
892            "name_2".to_string(),
893            "name_3".to_string(),
894            "name_4".to_string(),
895        ]));
896        let metrics = ExecutionPlanMetricsSet::new();
897        let baseline_metrics = BaselineMetrics::new(&metrics, 0);
898
899        let mut stream = RecordFetchStream::new(Box::pin(input_stream), fetcher, baseline_metrics);
900
901        let mut total_rows = 0;
902        while let Some(batch_result) = stream.next().await {
903            let batch = batch_result.unwrap();
904            total_rows += batch.num_rows();
905        }
906
907        assert_eq!(total_rows, 5, "Should have fetched all 5 rows");
908    }
909
910    // --- Tests ---
911
912    #[tokio::test]
913    async fn test_record_fetch_stream_eager_with_empty_batches() -> Result<()> {
914        // This test ensures that the stream is "eager" and will skip over empty
915        // input batches to find the next valid one within a single poll cycle.
916
917        // 1. Setup input stream with an empty batch in the middle
918        let schema = Arc::new(Schema::new(vec![Field::new(
919            PK_COL,
920            DataType::UInt64,
921            false,
922        )]));
923        let batch1 = RecordBatch::try_new(
924            schema.clone(),
925            vec![Arc::new(UInt64Array::from(vec![1, 2]))],
926        )?;
927        let empty_batch = RecordBatch::new_empty(schema.clone());
928        let batch2 = RecordBatch::try_new(
929            schema.clone(),
930            vec![Arc::new(UInt64Array::from(vec![3, 4]))],
931        )?;
932        let input_stream = MemoryStream::try_new(vec![batch1, empty_batch, batch2], schema, None)?;
933
934        // 2. Setup fetcher and stream
935        let names = (0..5).map(|i| format!("name_{i}")).collect();
936        let fetcher = Arc::new(SlowRecordFetcher::new(names));
937        let metrics = ExecutionPlanMetricsSet::new();
938        let baseline_metrics = BaselineMetrics::new(&metrics, 0);
939        let stream =
940            RecordFetchStream::new(Box::pin(input_stream), fetcher.clone(), baseline_metrics);
941
942        // 3. Collect results
943        let results = datafusion::physical_plan::common::collect(Box::pin(stream)).await?;
944
945        // 4. Assert results
946        let expected_batch1 = RecordBatch::try_new(
947            fetcher.schema(),
948            vec![
949                Arc::new(UInt64Array::from(vec![1, 2])),
950                Arc::new(StringArray::from(vec!["name_1", "name_2"])),
951            ],
952        )?;
953        let expected_batch2 = RecordBatch::try_new(
954            fetcher.schema(),
955            vec![
956                Arc::new(UInt64Array::from(vec![3, 4])),
957                Arc::new(StringArray::from(vec!["name_3", "name_4"])),
958            ],
959        )?;
960
961        assert_eq!(
962            results.len(),
963            2,
964            "Should have produced two non-empty batches"
965        );
966        assert_eq!(results[0], expected_batch1);
967        assert_eq!(results[1], expected_batch2);
968
969        Ok(())
970    }
971
972    #[tokio::test]
973    async fn test_record_fetch_exec_no_indexes() {
974        let fetcher = Arc::new(MockRecordFetcher::new());
975        let err = RecordFetchExec::try_new(
976            vec![],
977            None,
978            fetcher,
979            Arc::new(Schema::empty()),
980            UnionMode::Parallel,
981        )
982        .unwrap_err();
983        assert!(
984            matches!(err, DataFusionError::Plan(ref msg) if msg == "RecordFetchExec requires at least one index"),
985            "Unexpected error: {err:?}"
986        );
987    }
988
989    #[tokio::test]
990    async fn test_record_fetch_exec_single_index() -> Result<()> {
991        let index_batch = RecordBatch::try_from_iter(vec![(
992            PK_COL,
993            Arc::new(UInt64Array::from(vec![1, 3])) as _,
994        )])?;
995        let index = Arc::new(MockIndex::new(vec![index_batch]));
996        let indexes: Vec<IndexFilter> = vec![IndexFilter::Single {
997            index: index.clone() as Arc<dyn Index>,
998            filter: col("a").eq(lit(1)),
999        }];
1000
1001        let fetcher = Arc::new(MockRecordFetcher::new());
1002        let exec = RecordFetchExec::try_new(
1003            indexes,
1004            None,
1005            fetcher,
1006            Arc::new(Schema::empty()),
1007            UnionMode::Parallel,
1008        )?;
1009
1010        // The input plan should be just the IndexScanExec
1011        assert_eq!(exec.input.name(), "IndexScanExec");
1012        Ok(())
1013    }
1014
1015    #[tokio::test]
1016    async fn test_record_fetch_exec_multiple_indexes() -> Result<()> {
1017        // Create two indexes that return different row IDs
1018        let index1_batch = RecordBatch::try_from_iter(vec![(
1019            PK_COL,
1020            Arc::new(UInt64Array::from(vec![1, 3])) as _,
1021        )])?;
1022        let index1 = Arc::new(MockIndex::new(vec![index1_batch]));
1023
1024        let index2_batch = RecordBatch::try_from_iter(vec![(
1025            PK_COL,
1026            Arc::new(UInt64Array::from(vec![3, 5])) as _,
1027        )])?;
1028        let index2 = Arc::new(MockIndex::new(vec![index2_batch]));
1029
1030        let indexes = vec![IndexFilter::And(vec![
1031            IndexFilter::Single {
1032                index: index1,
1033                filter: col("a").eq(lit(1)),
1034            },
1035            IndexFilter::Single {
1036                index: index2,
1037                filter: col("a").eq(lit(1)),
1038            },
1039        ])];
1040
1041        let fetcher = Arc::new(MockRecordFetcher::new());
1042        let exec = RecordFetchExec::try_new(
1043            indexes,
1044            None,
1045            fetcher,
1046            Arc::new(Schema::empty()),
1047            UnionMode::Parallel,
1048        )?;
1049
1050        // The input plan should be a ProjectionExec wrapping a HashJoinExec
1051        assert_eq!(exec.input.name(), "ProjectionExec");
1052        let projection = exec
1053            .input
1054            .as_any()
1055            .downcast_ref::<ProjectionExec>()
1056            .unwrap();
1057        assert_eq!(projection.children()[0].name(), "HashJoinExec");
1058        Ok(())
1059    }
1060
1061    #[tokio::test]
1062    async fn test_record_fetch_exec_five_indexes() -> Result<()> {
1063        let mut indexes_vec = Vec::new();
1064        for i in 0..5 {
1065            let batch = RecordBatch::try_from_iter(vec![(
1066                PK_COL,
1067                Arc::new(UInt64Array::from(vec![i, i + 1, i + 2])) as _,
1068            )])?;
1069            indexes_vec.push(IndexFilter::Single {
1070                index: Arc::new(MockIndex::new(vec![batch])) as Arc<dyn Index>,
1071                filter: col("a").eq(lit(1)),
1072            });
1073        }
1074
1075        let indexes = vec![IndexFilter::And(indexes_vec)];
1076        let fetcher = Arc::new(MockRecordFetcher::new());
1077        let exec = RecordFetchExec::try_new(
1078            indexes,
1079            None,
1080            fetcher,
1081            Arc::new(Schema::empty()),
1082            UnionMode::Parallel,
1083        )?;
1084
1085        // The input plan should be a ProjectionExec at the top
1086        assert_eq!(exec.input.name(), "ProjectionExec");
1087
1088        fn count_joins(plan: &Arc<dyn ExecutionPlan>) -> usize {
1089            if let Some(join_exec) = plan.as_any().downcast_ref::<HashJoinExec>() {
1090                1 + count_joins(join_exec.children()[0]) + count_joins(join_exec.children()[1])
1091            } else {
1092                plan.children().iter().map(|c| count_joins(c)).sum()
1093            }
1094        }
1095
1096        let join_count = count_joins(&exec.input);
1097        assert_eq!(join_count, 4, "Expected 4 joins for 5 indexes");
1098
1099        Ok(())
1100    }
1101
1102    #[tokio::test]
1103    async fn test_record_fetch_exec_execute() -> Result<()> {
1104        // 1. Setup mocks
1105        let index_batch = RecordBatch::try_from_iter(vec![(
1106            PK_COL,
1107            Arc::new(UInt64Array::from(vec![1, 3, 5])) as _,
1108        )])?;
1109        let index = Arc::new(MockIndex::new(vec![index_batch]));
1110        let indexes = vec![IndexFilter::Single {
1111            index: index.clone() as Arc<dyn Index>,
1112            filter: col("a").eq(lit(1)),
1113        }];
1114
1115        let fetcher = Arc::new(MockRecordFetcher::new().with_data());
1116        let schema = fetcher.schema();
1117
1118        // 2. Create exec plan
1119        let exec =
1120            RecordFetchExec::try_new(indexes, None, fetcher, schema.clone(), UnionMode::Parallel)?;
1121
1122        // 3. Execute and collect results
1123        let task_ctx = Arc::new(TaskContext::default());
1124        let mut stream = exec.execute(0, task_ctx)?;
1125        let mut results = Vec::new();
1126        while let Some(batch) = stream.next().await {
1127            results.push(batch?);
1128        }
1129
1130        // 4. Assert results
1131        let expected_names = vec!["name_1", "name_3", "name_5"];
1132        let expected_batch = RecordBatch::try_new(
1133            schema.clone(),
1134            vec![
1135                Arc::new(UInt64Array::from(vec![1, 3, 5])),
1136                Arc::new(StringArray::from(expected_names)),
1137            ],
1138        )?;
1139
1140        assert_eq!(results.len(), 1);
1141        assert_eq!(results[0], expected_batch);
1142
1143        Ok(())
1144    }
1145
1146    #[tokio::test]
1147    async fn test_record_fetch_exec_execute_empty_input() -> Result<()> {
1148        // 1. Setup mocks with no batches
1149        let index = Arc::new(MockIndex::new(vec![]));
1150        let indexes = vec![IndexFilter::Single {
1151            index: index.clone() as Arc<dyn Index>,
1152            filter: col("a").eq(lit(1)),
1153        }];
1154        let fetcher = Arc::new(MockRecordFetcher::new().with_data());
1155
1156        // 2. Create exec plan
1157        let exec = RecordFetchExec::try_new(
1158            indexes,
1159            None,
1160            fetcher,
1161            Arc::new(Schema::empty()),
1162            UnionMode::Parallel,
1163        )?;
1164
1165        // 3. Execute and collect results
1166        let task_ctx = Arc::new(TaskContext::default());
1167        let mut stream = exec.execute(0, task_ctx)?;
1168        let mut results = Vec::new();
1169        while let Some(batch) = stream.next().await {
1170            results.push(batch?);
1171        }
1172
1173        // 4. Assert results are empty
1174        assert!(results.is_empty());
1175
1176        Ok(())
1177    }
1178
1179    #[tokio::test]
1180    async fn test_record_fetch_exec_execute_multiple_batches() -> Result<()> {
1181        // 1. Setup mocks with multiple batches
1182        let batch1 = RecordBatch::try_from_iter(vec![(
1183            PK_COL,
1184            Arc::new(UInt64Array::from(vec![1, 3])) as _,
1185        )])?;
1186        let batch2 = RecordBatch::try_from_iter(vec![(
1187            PK_COL,
1188            Arc::new(UInt64Array::from(vec![5, 7])) as _,
1189        )])?;
1190        let index = Arc::new(MockIndex::new(vec![batch1, batch2]));
1191        let indexes = vec![IndexFilter::Single {
1192            index: index.clone() as Arc<dyn Index>,
1193            filter: col("a").eq(lit(1)),
1194        }];
1195        let fetcher = Arc::new(MockRecordFetcher::new().with_data());
1196        let schema = fetcher.schema();
1197
1198        // 2. Create exec plan
1199        let exec =
1200            RecordFetchExec::try_new(indexes, None, fetcher, schema.clone(), UnionMode::Parallel)?;
1201
1202        // 3. Execute and collect results
1203        let task_ctx = Arc::new(TaskContext::default());
1204        let results =
1205            datafusion::physical_plan::common::collect(exec.execute(0, task_ctx)?).await?;
1206
1207        // 4. Assert results
1208        let expected_batch1 = RecordBatch::try_new(
1209            schema.clone(),
1210            vec![
1211                Arc::new(UInt64Array::from(vec![1, 3])),
1212                Arc::new(StringArray::from(vec!["name_1", "name_3"])),
1213            ],
1214        )?;
1215        let expected_batch2 = RecordBatch::try_new(
1216            schema.clone(),
1217            vec![
1218                Arc::new(UInt64Array::from(vec![5, 7])),
1219                Arc::new(StringArray::from(vec!["name_5", "name_7"])),
1220            ],
1221        )?;
1222
1223        assert_eq!(results.len(), 2);
1224        assert_eq!(results[0], expected_batch1);
1225        assert_eq!(results[1], expected_batch2);
1226
1227        Ok(())
1228    }
1229
1230    #[tokio::test]
1231    async fn test_record_fetch_exec_fetcher_error() -> Result<()> {
1232        // 1. Setup mocks
1233        #[derive(Debug)]
1234        struct ErrorFetcher;
1235        #[async_trait]
1236        impl RecordFetcher for ErrorFetcher {
1237            fn schema(&self) -> SchemaRef {
1238                Arc::new(Schema::empty())
1239            }
1240            async fn fetch(&self, _index_batch: RecordBatch) -> Result<RecordBatch> {
1241                Err(DataFusionError::Execution("fetcher error".to_string()))
1242            }
1243        }
1244
1245        let index_batch =
1246            RecordBatch::try_from_iter(vec![(PK_COL, Arc::new(UInt64Array::from(vec![1])) as _)])?;
1247        let index = Arc::new(MockIndex::new(vec![index_batch]));
1248        let indexes = vec![IndexFilter::Single {
1249            index: index.clone() as Arc<dyn Index>,
1250            filter: col("a").eq(lit(1)),
1251        }];
1252        let fetcher = Arc::new(ErrorFetcher);
1253
1254        // 2. Create exec plan
1255        let exec = RecordFetchExec::try_new(
1256            indexes,
1257            None,
1258            fetcher,
1259            Arc::new(Schema::empty()),
1260            UnionMode::Parallel,
1261        )?;
1262
1263        // 3. Execute and expect an error
1264        let task_ctx = Arc::new(TaskContext::default());
1265        let result = datafusion::physical_plan::common::collect(exec.execute(0, task_ctx)?).await;
1266
1267        assert!(result.is_err());
1268        assert!(matches!(result.unwrap_err(), DataFusionError::Execution(_)));
1269
1270        Ok(())
1271    }
1272}