// datafusion_physical_plan/limit.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines the LIMIT plan
19
20use std::any::Any;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
26use super::{
27    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
28    SendableRecordBatchStream, Statistics,
29};
30use crate::execution_plan::{Boundedness, CardinalityEffect};
31use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning};
32
33use arrow::datatypes::SchemaRef;
34use arrow::record_batch::RecordBatch;
35use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
36use datafusion_execution::TaskContext;
37
38use futures::stream::{Stream, StreamExt};
39use log::trace;
40
/// Limit execution plan
#[derive(Debug, Clone)]
pub struct GlobalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Number of rows to skip before fetch
    skip: usize,
    /// Maximum number of rows to fetch,
    /// `None` means fetching all rows
    fetch: Option<usize>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties (partitioning, boundedness, etc.) computed once at construction
    cache: PlanProperties,
}
55
56impl GlobalLimitExec {
57    /// Create a new GlobalLimitExec
58    pub fn new(input: Arc<dyn ExecutionPlan>, skip: usize, fetch: Option<usize>) -> Self {
59        let cache = Self::compute_properties(&input);
60        GlobalLimitExec {
61            input,
62            skip,
63            fetch,
64            metrics: ExecutionPlanMetricsSet::new(),
65            cache,
66        }
67    }
68
69    /// Input execution plan
70    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
71        &self.input
72    }
73
74    /// Number of rows to skip before fetch
75    pub fn skip(&self) -> usize {
76        self.skip
77    }
78
79    /// Maximum number of rows to fetch
80    pub fn fetch(&self) -> Option<usize> {
81        self.fetch
82    }
83
84    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
85    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
86        PlanProperties::new(
87            input.equivalence_properties().clone(), // Equivalence Properties
88            Partitioning::UnknownPartitioning(1),   // Output Partitioning
89            input.pipeline_behavior(),
90            // Limit operations are always bounded since they output a finite number of rows
91            Boundedness::Bounded,
92        )
93    }
94}
95
96impl DisplayAs for GlobalLimitExec {
97    fn fmt_as(
98        &self,
99        t: DisplayFormatType,
100        f: &mut std::fmt::Formatter,
101    ) -> std::fmt::Result {
102        match t {
103            DisplayFormatType::Default | DisplayFormatType::Verbose => {
104                write!(
105                    f,
106                    "GlobalLimitExec: skip={}, fetch={}",
107                    self.skip,
108                    self.fetch
109                        .map_or_else(|| "None".to_string(), |x| x.to_string())
110                )
111            }
112            DisplayFormatType::TreeRender => {
113                if let Some(fetch) = self.fetch {
114                    writeln!(f, "limit={fetch}")?;
115                }
116                write!(f, "skip={}", self.skip)
117            }
118        }
119    }
120}
121
122impl ExecutionPlan for GlobalLimitExec {
123    fn name(&self) -> &'static str {
124        "GlobalLimitExec"
125    }
126
127    /// Return a reference to Any that can be used for downcasting
128    fn as_any(&self) -> &dyn Any {
129        self
130    }
131
132    fn properties(&self) -> &PlanProperties {
133        &self.cache
134    }
135
136    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
137        vec![&self.input]
138    }
139
140    fn required_input_distribution(&self) -> Vec<Distribution> {
141        vec![Distribution::SinglePartition]
142    }
143
144    fn maintains_input_order(&self) -> Vec<bool> {
145        vec![true]
146    }
147
148    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
149        vec![false]
150    }
151
152    fn with_new_children(
153        self: Arc<Self>,
154        children: Vec<Arc<dyn ExecutionPlan>>,
155    ) -> Result<Arc<dyn ExecutionPlan>> {
156        Ok(Arc::new(GlobalLimitExec::new(
157            Arc::clone(&children[0]),
158            self.skip,
159            self.fetch,
160        )))
161    }
162
163    fn execute(
164        &self,
165        partition: usize,
166        context: Arc<TaskContext>,
167    ) -> Result<SendableRecordBatchStream> {
168        trace!("Start GlobalLimitExec::execute for partition: {partition}");
169        // GlobalLimitExec has a single output partition
170        assert_eq_or_internal_err!(
171            partition,
172            0,
173            "GlobalLimitExec invalid partition {partition}"
174        );
175
176        // GlobalLimitExec requires a single input partition
177        assert_eq_or_internal_err!(
178            self.input.output_partitioning().partition_count(),
179            1,
180            "GlobalLimitExec requires a single input partition"
181        );
182
183        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
184        let stream = self.input.execute(0, context)?;
185        Ok(Box::pin(LimitStream::new(
186            stream,
187            self.skip,
188            self.fetch,
189            baseline_metrics,
190        )))
191    }
192
193    fn metrics(&self) -> Option<MetricsSet> {
194        Some(self.metrics.clone_inner())
195    }
196
197    fn statistics(&self) -> Result<Statistics> {
198        self.partition_statistics(None)
199    }
200
201    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
202        self.input
203            .partition_statistics(partition)?
204            .with_fetch(self.fetch, self.skip, 1)
205    }
206
207    fn fetch(&self) -> Option<usize> {
208        self.fetch
209    }
210
211    fn supports_limit_pushdown(&self) -> bool {
212        true
213    }
214}
215
/// LocalLimitExec applies a limit to a single partition
#[derive(Debug)]
pub struct LocalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Maximum number of rows to return
    fetch: usize,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Plan properties (partitioning, boundedness, etc.) computed once at construction
    cache: PlanProperties,
}
227
228impl LocalLimitExec {
229    /// Create a new LocalLimitExec partition
230    pub fn new(input: Arc<dyn ExecutionPlan>, fetch: usize) -> Self {
231        let cache = Self::compute_properties(&input);
232        Self {
233            input,
234            fetch,
235            metrics: ExecutionPlanMetricsSet::new(),
236            cache,
237        }
238    }
239
240    /// Input execution plan
241    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
242        &self.input
243    }
244
245    /// Maximum number of rows to fetch
246    pub fn fetch(&self) -> usize {
247        self.fetch
248    }
249
250    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
251    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
252        PlanProperties::new(
253            input.equivalence_properties().clone(), // Equivalence Properties
254            input.output_partitioning().clone(),    // Output Partitioning
255            input.pipeline_behavior(),
256            // Limit operations are always bounded since they output a finite number of rows
257            Boundedness::Bounded,
258        )
259    }
260}
261
262impl DisplayAs for LocalLimitExec {
263    fn fmt_as(
264        &self,
265        t: DisplayFormatType,
266        f: &mut std::fmt::Formatter,
267    ) -> std::fmt::Result {
268        match t {
269            DisplayFormatType::Default | DisplayFormatType::Verbose => {
270                write!(f, "LocalLimitExec: fetch={}", self.fetch)
271            }
272            DisplayFormatType::TreeRender => {
273                write!(f, "limit={}", self.fetch)
274            }
275        }
276    }
277}
278
impl ExecutionPlan for LocalLimitExec {
    fn name(&self) -> &'static str {
        "LocalLimitExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    // Repartitioning the input cannot help: the limit is applied per partition
    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    // Taking a prefix of each partition preserves the input's ordering
    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.len() {
            1 => Ok(Arc::new(LocalLimitExec::new(
                Arc::clone(&children[0]),
                self.fetch,
            ))),
            _ => internal_err!("LocalLimitExec wrong number of children"),
        }
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start LocalLimitExec::execute for partition {} of context session_id {} and task_id {:?}",
            partition,
            context.session_id(),
            context.task_id()
        );
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let stream = self.input.execute(partition, context)?;
        // skip=0: a local limit only fetches, it never skips rows
        Ok(Box::pin(LimitStream::new(
            stream,
            0,
            Some(self.fetch),
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        // Derive statistics from the input, capped by the fetch limit
        self.input
            .partition_statistics(partition)?
            .with_fetch(Some(self.fetch), 0, 1)
    }

    fn fetch(&self) -> Option<usize> {
        Some(self.fetch)
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    // A limit never produces more rows than its input
    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::LowerEqual
    }
}
365
/// A Limit stream skips `skip` rows, and then fetch up to `fetch` rows.
pub struct LimitStream {
    /// The remaining number of rows to skip
    skip: usize,
    /// The remaining number of rows to produce
    /// (`usize::MAX` when no fetch limit was given)
    fetch: usize,
    /// The input to read from. This is set to None once the limit is
    /// reached to enable early termination
    input: Option<SendableRecordBatchStream>,
    /// Copy of the input schema, kept so `schema()` works after `input` is dropped
    schema: SchemaRef,
    /// Execution time metrics
    baseline_metrics: BaselineMetrics,
}
380
381impl LimitStream {
382    pub fn new(
383        input: SendableRecordBatchStream,
384        skip: usize,
385        fetch: Option<usize>,
386        baseline_metrics: BaselineMetrics,
387    ) -> Self {
388        let schema = input.schema();
389        Self {
390            skip,
391            fetch: fetch.unwrap_or(usize::MAX),
392            input: Some(input),
393            schema,
394            baseline_metrics,
395        }
396    }
397
398    fn poll_and_skip(
399        &mut self,
400        cx: &mut Context<'_>,
401    ) -> Poll<Option<Result<RecordBatch>>> {
402        let input = self.input.as_mut().unwrap();
403        loop {
404            let poll = input.poll_next_unpin(cx);
405            let poll = poll.map_ok(|batch| {
406                if batch.num_rows() <= self.skip {
407                    self.skip -= batch.num_rows();
408                    RecordBatch::new_empty(input.schema())
409                } else {
410                    let new_batch = batch.slice(self.skip, batch.num_rows() - self.skip);
411                    self.skip = 0;
412                    new_batch
413                }
414            });
415
416            match &poll {
417                Poll::Ready(Some(Ok(batch))) => {
418                    if batch.num_rows() > 0 {
419                        break poll;
420                    } else {
421                        // Continue to poll input stream
422                    }
423                }
424                Poll::Ready(Some(Err(_e))) => break poll,
425                Poll::Ready(None) => break poll,
426                Poll::Pending => break poll,
427            }
428        }
429    }
430
431    /// Fetches from the batch
432    fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
433        // records time on drop
434        let _timer = self.baseline_metrics.elapsed_compute().timer();
435        if self.fetch == 0 {
436            self.input = None; // Clear input so it can be dropped early
437            None
438        } else if batch.num_rows() < self.fetch {
439            //
440            self.fetch -= batch.num_rows();
441            Some(batch)
442        } else if batch.num_rows() >= self.fetch {
443            let batch_rows = self.fetch;
444            self.fetch = 0;
445            self.input = None; // Clear input so it can be dropped early
446
447            // It is guaranteed that batch_rows is <= batch.num_rows
448            Some(batch.slice(0, batch_rows))
449        } else {
450            unreachable!()
451        }
452    }
453}
454
impl Stream for LimitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Once `skip` reaches zero, batches can be forwarded straight to the
        // fetch logic without slicing off an offset first
        let fetch_started = self.skip == 0;
        let poll = match &mut self.input {
            Some(input) => {
                let poll = if fetch_started {
                    input.poll_next_unpin(cx)
                } else {
                    self.poll_and_skip(cx)
                };

                // Apply the fetch limit; `stream_limit` returns `None` once
                // the limit is reached, which terminates the stream
                poll.map(|x| match x {
                    Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
                    other => other,
                })
            }
            // Input has been cleared
            None => Poll::Ready(None),
        };

        // Record output rows / end-of-stream in the baseline metrics
        self.baseline_metrics.record_poll(poll)
    }
}
483
impl RecordBatchStream for LimitStream {
    /// Get the schema of the batches produced by this stream
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
490
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::common::collect;
    use crate::test;

    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use arrow::array::RecordBatchOptions;
    use arrow::datatypes::Schema;
    use datafusion_common::stats::Precision;
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_expr::expressions::col;

    #[tokio::test]
    async fn limit() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        // Input should have 4 partitions
        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let limit =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7));

        // Collect the single output partition of the limit
        let iter = limit.execute(0, task_ctx)?;
        let batches = collect(iter).await?;

        // Of the 400 input rows only the first 7 should be returned
        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 7);

        Ok(())
    }

    #[tokio::test]
    async fn limit_early_shutdown() -> Result<()> {
        let batches = vec![
            test::make_partition(5),
            test::make_partition(10),
            test::make_partition(15),
            test::make_partition(20),
            test::make_partition(25),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // Limit of six needs to consume the entire first record batch
        // (5 rows) and 1 row from the second (1 row)
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // Only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // Only the first two batches should be consumed
        assert_eq!(index.value(), 2);

        Ok(())
    }

    #[tokio::test]
    async fn limit_equals_batch_size() -> Result<()> {
        let batches = vec![
            test::make_partition(6),
            test::make_partition(6),
            test::make_partition(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // Limit of six needs to consume the entire first record batch
        // (6 rows) and stop immediately
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // Only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // Only the first batch should be consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn limit_no_column() -> Result<()> {
        let batches = vec![
            make_batch_no_column(6),
            make_batch_no_column(6),
            make_batch_no_column(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // Limit of six needs to consume the entire first record batch
        // (6 rows) and stop immediately
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // Only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // Only the first batch should be consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    // Test cases for "skip"
    async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> {
        let task_ctx = Arc::new(TaskContext::default());

        // 4 partitions @ 100 rows apiece
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        // Collect the single output partition and count the surviving rows
        let iter = offset.execute(0, task_ctx)?;
        let batches = collect(iter).await?;
        Ok(batches.iter().map(|batch| batch.num_rows()).sum())
    }

    #[tokio::test]
    async fn skip_none_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(0, None).await?;
        assert_eq!(row_count, 400);
        Ok(())
    }

    #[tokio::test]
    async fn skip_none_fetch_50() -> Result<()> {
        let row_count = skip_and_fetch(0, Some(50)).await?;
        assert_eq!(row_count, 50);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_none() -> Result<()> {
        // There are total of 400 rows, we skipped 3 rows (offset = 3)
        let row_count = skip_and_fetch(3, None).await?;
        assert_eq!(row_count, 397);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_10_stats() -> Result<()> {
        // There are total of 400 rows, we skipped 3 rows (offset = 3)
        let row_count = skip_and_fetch(3, Some(10)).await?;
        assert_eq!(row_count, 10);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(400, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_1() -> Result<()> {
        // There are a total of 400 rows
        let row_count = skip_and_fetch(400, Some(1)).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_401_fetch_none() -> Result<()> {
        // There are total of 400 rows, we skipped 401 rows (offset = 401)
        let row_count = skip_and_fetch(401, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_global_limit() -> Result<()> {
        let row_count = row_number_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Exact(1));

        let row_count = row_number_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(400));

        let row_count =
            row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Inexact(1));

        let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(400));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_local_limit() -> Result<()> {
        let row_count = row_number_statistics_for_local_limit(4, 10).await?;
        assert_eq!(row_count, Precision::Exact(10));

        Ok(())
    }

    async fn row_number_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    pub fn build_group_by(
        input_schema: &SchemaRef,
        columns: Vec<String>,
    ) -> PhysicalGroupBy {
        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
        for column in columns.iter() {
            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
        }
        PhysicalGroupBy::new_single(group_by_expr.clone())
    }

    async fn row_number_inexact_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        // Adding a "GROUP BY i" changes the input stats from Exact to Inexact.
        let agg = AggregateExec::try_new(
            AggregateMode::Final,
            build_group_by(&csv.schema(), vec!["i".to_string()]),
            vec![],
            vec![],
            Arc::clone(&csv),
            Arc::clone(&csv.schema()),
        )?;
        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);

        let offset = GlobalLimitExec::new(
            Arc::new(CoalescePartitionsExec::new(agg_exec)),
            skip,
            fetch,
        );

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    async fn row_number_statistics_for_local_limit(
        num_partitions: usize,
        fetch: usize,
    ) -> Result<Precision<usize>> {
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset = LocalLimitExec::new(csv, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    /// Return a RecordBatch with no columns but `sz` rows
    /// (row count supplied via `RecordBatchOptions`)
    fn make_batch_no_column(sz: usize) -> RecordBatch {
        let schema = Arc::new(Schema::empty());

        let options = RecordBatchOptions::new().with_row_count(Option::from(sz));
        RecordBatch::try_new_with_options(schema, vec![], &options).unwrap()
    }
}