datafusion_physical_plan/limit.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines the LIMIT plan

use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream, Statistics,
};
use crate::execution_plan::{Boundedness, CardinalityEffect};
use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning};

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{internal_err, Result};
use datafusion_execution::TaskContext;

use futures::stream::{Stream, StreamExt};
use log::trace;

/// Limit execution plan: skips `skip` input rows and then emits at most
/// `fetch` rows (the physical operator for `LIMIT .. OFFSET ..`)
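///
/// For example, `SELECT * FROM t OFFSET 3 LIMIT 10` could map to the plan
/// sketched below (illustrative only, not necessarily the planner's exact
/// output; `input` stands for any single-partition `ExecutionPlan`):
///
/// ```text
/// // skip the first 3 rows, then emit at most 10 rows
/// GlobalLimitExec::new(input, 3, Some(10))
/// ```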
#[derive(Debug, Clone)]
pub struct GlobalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Number of rows to skip before fetch
    skip: usize,
    /// Maximum number of rows to fetch,
    /// `None` means fetching all rows
    fetch: Option<usize>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (see [`Self::compute_properties`])
    cache: PlanProperties,
}

impl GlobalLimitExec {
    /// Create a new GlobalLimitExec
    pub fn new(input: Arc<dyn ExecutionPlan>, skip: usize, fetch: Option<usize>) -> Self {
        let cache = Self::compute_properties(&input);
        GlobalLimitExec {
            input,
            skip,
            fetch,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        }
    }

    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Number of rows to skip before fetch
    pub fn skip(&self) -> usize {
        self.skip
    }

    /// Maximum number of rows to fetch
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Creates the cache object that stores the plan properties such as
    /// schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        PlanProperties::new(
            input.equivalence_properties().clone(), // Equivalence Properties
            Partitioning::UnknownPartitioning(1),   // Output Partitioning
            input.pipeline_behavior(),              // Pipeline Behavior
            // Limit operations are always bounded since they output a finite number of rows
            Boundedness::Bounded,
        )
    }
}

impl DisplayAs for GlobalLimitExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "GlobalLimitExec: skip={}, fetch={}",
                    self.skip,
                    self.fetch
                        .map_or_else(|| "None".to_string(), |x| x.to_string())
                )
            }
            DisplayFormatType::TreeRender => {
                if let Some(fetch) = self.fetch {
                    writeln!(f, "limit={fetch}")?;
                }
                write!(f, "skip={}", self.skip)
            }
        }
    }
}

impl ExecutionPlan for GlobalLimitExec {
    fn name(&self) -> &'static str {
        "GlobalLimitExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(GlobalLimitExec::new(
            Arc::clone(&children[0]),
            self.skip,
            self.fetch,
        )))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!("Start GlobalLimitExec::execute for partition: {partition}");
        // GlobalLimitExec has a single output partition
        if 0 != partition {
            return internal_err!("GlobalLimitExec invalid partition {partition}");
        }

        // GlobalLimitExec requires a single input partition
        if 1 != self.input.output_partitioning().partition_count() {
            return internal_err!("GlobalLimitExec requires a single input partition");
        }

        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let stream = self.input.execute(0, context)?;
        Ok(Box::pin(LimitStream::new(
            stream,
            self.skip,
            self.fetch,
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input
            .partition_statistics(partition)?
            .with_fetch(self.fetch, self.skip, 1)
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}

/// LocalLimitExec applies a limit (`fetch`) to each partition of its input
/// separately
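///
/// Combined with [`GlobalLimitExec`], this supports limiting partitioned
/// input: each partition is truncated locally, the partitions are merged,
/// and the limit is applied once more globally. A sketch (`input` stands
/// for a hypothetical multi-partition plan):
///
/// ```text
/// // keep at most 10 rows from each partition ...
/// let local = Arc::new(LocalLimitExec::new(input, 10));
/// // ... then merge partitions and keep at most 10 rows overall
/// let global = GlobalLimitExec::new(
///     Arc::new(CoalescePartitionsExec::new(local)),
///     0,
///     Some(10),
/// );
/// ```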
#[derive(Debug)]
pub struct LocalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Maximum number of rows to return
    fetch: usize,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (see [`Self::compute_properties`])
    cache: PlanProperties,
}

impl LocalLimitExec {
    /// Create a new LocalLimitExec
    pub fn new(input: Arc<dyn ExecutionPlan>, fetch: usize) -> Self {
        let cache = Self::compute_properties(&input);
        Self {
            input,
            fetch,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        }
    }

    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Maximum number of rows to fetch
    pub fn fetch(&self) -> usize {
        self.fetch
    }

    /// Creates the cache object that stores the plan properties such as
    /// schema, equivalence properties, ordering, partitioning, etc.
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        PlanProperties::new(
            input.equivalence_properties().clone(), // Equivalence Properties
            input.output_partitioning().clone(),    // Output Partitioning
            input.pipeline_behavior(),              // Pipeline Behavior
            // Limit operations are always bounded since they output a finite number of rows
            Boundedness::Bounded,
        )
    }
}

impl DisplayAs for LocalLimitExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "LocalLimitExec: fetch={}", self.fetch)
            }
            DisplayFormatType::TreeRender => {
                write!(f, "limit={}", self.fetch)
            }
        }
    }
}

impl ExecutionPlan for LocalLimitExec {
    fn name(&self) -> &'static str {
        "LocalLimitExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.len() {
            1 => Ok(Arc::new(LocalLimitExec::new(
                Arc::clone(&children[0]),
                self.fetch,
            ))),
            _ => internal_err!("LocalLimitExec wrong number of children"),
        }
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start LocalLimitExec::execute for partition {} of context session_id {} and task_id {:?}",
            partition,
            context.session_id(),
            context.task_id()
        );
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let stream = self.input.execute(partition, context)?;
        Ok(Box::pin(LimitStream::new(
            stream,
            0,
            Some(self.fetch),
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input
            .partition_statistics(partition)?
            .with_fetch(Some(self.fetch), 0, 1)
    }

    fn fetch(&self) -> Option<usize> {
        Some(self.fetch)
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }
}

/// A Limit stream skips `skip` rows, and then fetches up to `fetch` rows.
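///
/// For example, wrapping a stream to implement `OFFSET 2 LIMIT 3` could look
/// like the sketch below (`input`, `metrics`, and `partition` are assumed to
/// be in scope):
///
/// ```text
/// let baseline_metrics = BaselineMetrics::new(&metrics, partition);
/// // skip 2 rows, then yield at most 3 more before dropping the input
/// let stream = LimitStream::new(input, 2, Some(3), baseline_metrics);
/// ```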
pub struct LimitStream {
    /// The remaining number of rows to skip
    skip: usize,
    /// The remaining number of rows to produce
    fetch: usize,
    /// The input to read from. This is set to None once the limit is
    /// reached to enable early termination
    input: Option<SendableRecordBatchStream>,
    /// Copy of the input schema
    schema: SchemaRef,
    /// Execution time metrics
    baseline_metrics: BaselineMetrics,
}

impl LimitStream {
    /// Create a new LimitStream. A `fetch` of `None` means fetch all rows.
    pub fn new(
        input: SendableRecordBatchStream,
        skip: usize,
        fetch: Option<usize>,
        baseline_metrics: BaselineMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            skip,
            fetch: fetch.unwrap_or(usize::MAX),
            input: Some(input),
            schema,
            baseline_metrics,
        }
    }

    /// Polls the input until `self.skip` rows have been skipped, returning
    /// the first batch (sliced past the skip point) that contains rows to emit
    fn poll_and_skip(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        let input = self.input.as_mut().unwrap();
        loop {
            let poll = input.poll_next_unpin(cx);
            let poll = poll.map_ok(|batch| {
                if batch.num_rows() <= self.skip {
                    // The entire batch is skipped
                    self.skip -= batch.num_rows();
                    RecordBatch::new_empty(input.schema())
                } else {
                    // Skip the remaining rows and keep the rest of the batch
                    let new_batch = batch.slice(self.skip, batch.num_rows() - self.skip);
                    self.skip = 0;
                    new_batch
                }
            });

            match &poll {
                Poll::Ready(Some(Ok(batch))) => {
                    if batch.num_rows() > 0 {
                        break poll;
                    } else {
                        // Continue to poll input stream
                    }
                }
                Poll::Ready(Some(Err(_e))) => break poll,
                Poll::Ready(None) => break poll,
                Poll::Pending => break poll,
            }
        }
    }

    /// Limits the batch to the remaining `fetch` rows, clearing the input
    /// once the limit has been reached so it can be dropped early
    fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
        // records time on drop
        let _timer = self.baseline_metrics.elapsed_compute().timer();
        if self.fetch == 0 {
            self.input = None; // Clear input so it can be dropped early
            None
        } else if batch.num_rows() < self.fetch {
            // Take the entire batch and decrement the remaining fetch count
            self.fetch -= batch.num_rows();
            Some(batch)
        } else {
            // batch.num_rows() >= self.fetch: this batch completes the limit
            let batch_rows = self.fetch;
            self.fetch = 0;
            self.input = None; // Clear input so it can be dropped early

            // It is guaranteed that batch_rows is <= batch.num_rows
            Some(batch.slice(0, batch_rows))
        }
    }
}

impl Stream for LimitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let fetch_started = self.skip == 0;
        let poll = match &mut self.input {
            Some(input) => {
                let poll = if fetch_started {
                    input.poll_next_unpin(cx)
                } else {
                    self.poll_and_skip(cx)
                };

                poll.map(|x| match x {
                    Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
                    other => other,
                })
            }
            // Input has been cleared
            None => Poll::Ready(None),
        };

        self.baseline_metrics.record_poll(poll)
    }
}

impl RecordBatchStream for LimitStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::common::collect;
    use crate::test;

    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use arrow::array::RecordBatchOptions;
    use arrow::datatypes::Schema;
    use datafusion_common::stats::Precision;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr::PhysicalExpr;

    #[tokio::test]
    async fn limit() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        // Input should have 4 partitions
        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let limit =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7));

        // Execute the limit on the single coalesced partition
        let iter = limit.execute(0, task_ctx)?;
        let batches = collect(iter).await?;

        // Only the first 7 rows should be returned
        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 7);

        Ok(())
    }

    #[tokio::test]
    async fn limit_early_shutdown() -> Result<()> {
        let batches = vec![
            test::make_partition(5),
            test::make_partition(10),
            test::make_partition(15),
            test::make_partition(20),
            test::make_partition(25),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // Limit of six needs to consume the entire first record batch
        // (5 rows) and 1 row from the second
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // Only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // Only the first two batches should be consumed
        assert_eq!(index.value(), 2);

        Ok(())
    }

    #[tokio::test]
    async fn limit_equals_batch_size() -> Result<()> {
        let batches = vec![
            test::make_partition(6),
            test::make_partition(6),
            test::make_partition(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // Limit of six needs to consume the entire first record batch
        // (6 rows) and stop immediately
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // Only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // Only the first batch should be consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn limit_no_column() -> Result<()> {
        let batches = vec![
            make_batch_no_column(6),
            make_batch_no_column(6),
            make_batch_no_column(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // Limit of six needs to consume the entire first record batch
        // (6 rows) and stop immediately
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // Only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // Only the first batch should be consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    // Test cases for "skip"
    async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> {
        let task_ctx = Arc::new(TaskContext::default());

        // 4 partitions @ 100 rows apiece
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        // Execute the limit on the single coalesced partition
        let iter = offset.execute(0, task_ctx)?;
        let batches = collect(iter).await?;
        Ok(batches.iter().map(|batch| batch.num_rows()).sum())
    }

    #[tokio::test]
    async fn skip_none_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(0, None).await?;
        assert_eq!(row_count, 400);
        Ok(())
    }

    #[tokio::test]
    async fn skip_none_fetch_50() -> Result<()> {
        let row_count = skip_and_fetch(0, Some(50)).await?;
        assert_eq!(row_count, 50);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_none() -> Result<()> {
        // There are a total of 400 rows and we skip 3 (offset = 3)
        let row_count = skip_and_fetch(3, None).await?;
        assert_eq!(row_count, 397);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_10_stats() -> Result<()> {
        // There are a total of 400 rows; we skip 3 (offset = 3) and fetch 10
        let row_count = skip_and_fetch(3, Some(10)).await?;
        assert_eq!(row_count, 10);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(400, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_1() -> Result<()> {
        // There are a total of 400 rows; skipping all of them leaves nothing to fetch
        let row_count = skip_and_fetch(400, Some(1)).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_401_fetch_none() -> Result<()> {
        // There are a total of 400 rows and we skip 401 (offset = 401)
        let row_count = skip_and_fetch(401, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_global_limit() -> Result<()> {
        let row_count = row_number_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Exact(1));

        let row_count = row_number_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(400));

        let row_count =
            row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Inexact(1));

        let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(400));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_local_limit() -> Result<()> {
        let row_count = row_number_statistics_for_local_limit(4, 10).await?;
        assert_eq!(row_count, Precision::Exact(10));

        Ok(())
    }

    async fn row_number_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    pub fn build_group_by(
        input_schema: &SchemaRef,
        columns: Vec<String>,
    ) -> PhysicalGroupBy {
        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
        for column in columns.iter() {
            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
        }
        PhysicalGroupBy::new_single(group_by_expr)
    }

    async fn row_number_inexact_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        // Adding a "GROUP BY i" changes the input stats from Exact to Inexact.
        let agg = AggregateExec::try_new(
            AggregateMode::Final,
            build_group_by(&csv.schema(), vec!["i".to_string()]),
            vec![],
            vec![],
            Arc::clone(&csv),
            Arc::clone(&csv.schema()),
        )?;
        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);

        let offset = GlobalLimitExec::new(
            Arc::new(CoalescePartitionsExec::new(agg_exec)),
            skip,
            fetch,
        );

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    async fn row_number_statistics_for_local_limit(
        num_partitions: usize,
        fetch: usize,
    ) -> Result<Precision<usize>> {
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset = LocalLimitExec::new(csv, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    /// Return a RecordBatch with no columns and `sz` total rows
    fn make_batch_no_column(sz: usize) -> RecordBatch {
        let schema = Arc::new(Schema::empty());

        let options = RecordBatchOptions::new().with_row_count(Some(sz));
        RecordBatch::try_new_with_options(schema, vec![], &options).unwrap()
    }
}