1use std::any::Any;
21use std::pin::Pin;
22use std::sync::Arc;
23use std::task::{Context, Poll};
24
25use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
26use super::{
27 DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
28 SendableRecordBatchStream, Statistics,
29};
30use crate::execution_plan::{Boundedness, CardinalityEffect};
31use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning};
32
33use arrow::datatypes::SchemaRef;
34use arrow::record_batch::RecordBatch;
35use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
36use datafusion_execution::TaskContext;
37
38use futures::stream::{Stream, StreamExt};
39use log::trace;
40
/// Limit execution plan: skips `skip` rows and then emits at most `fetch`
/// rows from its single input partition.
#[derive(Debug, Clone)]
pub struct GlobalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Number of leading rows to skip before emitting anything
    skip: usize,
    /// Maximum number of rows to emit after skipping (`None` = no cap)
    fetch: Option<usize>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalences, partitioning, boundedness)
    cache: PlanProperties,
}
55
56impl GlobalLimitExec {
57 pub fn new(input: Arc<dyn ExecutionPlan>, skip: usize, fetch: Option<usize>) -> Self {
59 let cache = Self::compute_properties(&input);
60 GlobalLimitExec {
61 input,
62 skip,
63 fetch,
64 metrics: ExecutionPlanMetricsSet::new(),
65 cache,
66 }
67 }
68
69 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
71 &self.input
72 }
73
74 pub fn skip(&self) -> usize {
76 self.skip
77 }
78
79 pub fn fetch(&self) -> Option<usize> {
81 self.fetch
82 }
83
84 fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
86 PlanProperties::new(
87 input.equivalence_properties().clone(), Partitioning::UnknownPartitioning(1), input.pipeline_behavior(),
90 Boundedness::Bounded,
92 )
93 }
94}
95
96impl DisplayAs for GlobalLimitExec {
97 fn fmt_as(
98 &self,
99 t: DisplayFormatType,
100 f: &mut std::fmt::Formatter,
101 ) -> std::fmt::Result {
102 match t {
103 DisplayFormatType::Default | DisplayFormatType::Verbose => {
104 write!(
105 f,
106 "GlobalLimitExec: skip={}, fetch={}",
107 self.skip,
108 self.fetch
109 .map_or_else(|| "None".to_string(), |x| x.to_string())
110 )
111 }
112 DisplayFormatType::TreeRender => {
113 if let Some(fetch) = self.fetch {
114 writeln!(f, "limit={fetch}")?;
115 }
116 write!(f, "skip={}", self.skip)
117 }
118 }
119 }
120}
121
122impl ExecutionPlan for GlobalLimitExec {
123 fn name(&self) -> &'static str {
124 "GlobalLimitExec"
125 }
126
127 fn as_any(&self) -> &dyn Any {
129 self
130 }
131
132 fn properties(&self) -> &PlanProperties {
133 &self.cache
134 }
135
136 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
137 vec![&self.input]
138 }
139
140 fn required_input_distribution(&self) -> Vec<Distribution> {
141 vec![Distribution::SinglePartition]
142 }
143
144 fn maintains_input_order(&self) -> Vec<bool> {
145 vec![true]
146 }
147
148 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
149 vec![false]
150 }
151
152 fn with_new_children(
153 self: Arc<Self>,
154 children: Vec<Arc<dyn ExecutionPlan>>,
155 ) -> Result<Arc<dyn ExecutionPlan>> {
156 Ok(Arc::new(GlobalLimitExec::new(
157 Arc::clone(&children[0]),
158 self.skip,
159 self.fetch,
160 )))
161 }
162
163 fn execute(
164 &self,
165 partition: usize,
166 context: Arc<TaskContext>,
167 ) -> Result<SendableRecordBatchStream> {
168 trace!("Start GlobalLimitExec::execute for partition: {partition}");
169 assert_eq_or_internal_err!(
171 partition,
172 0,
173 "GlobalLimitExec invalid partition {partition}"
174 );
175
176 assert_eq_or_internal_err!(
178 self.input.output_partitioning().partition_count(),
179 1,
180 "GlobalLimitExec requires a single input partition"
181 );
182
183 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
184 let stream = self.input.execute(0, context)?;
185 Ok(Box::pin(LimitStream::new(
186 stream,
187 self.skip,
188 self.fetch,
189 baseline_metrics,
190 )))
191 }
192
193 fn metrics(&self) -> Option<MetricsSet> {
194 Some(self.metrics.clone_inner())
195 }
196
197 fn statistics(&self) -> Result<Statistics> {
198 self.partition_statistics(None)
199 }
200
201 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
202 self.input
203 .partition_statistics(partition)?
204 .with_fetch(self.fetch, self.skip, 1)
205 }
206
207 fn fetch(&self) -> Option<usize> {
208 self.fetch
209 }
210
211 fn supports_limit_pushdown(&self) -> bool {
212 true
213 }
214}
215
/// LocalLimitExec applies a limit to each partition of its input
/// independently (no skip).
#[derive(Debug)]
pub struct LocalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Maximum number of rows to return per partition
    fetch: usize,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties
    cache: PlanProperties,
}
227
228impl LocalLimitExec {
229 pub fn new(input: Arc<dyn ExecutionPlan>, fetch: usize) -> Self {
231 let cache = Self::compute_properties(&input);
232 Self {
233 input,
234 fetch,
235 metrics: ExecutionPlanMetricsSet::new(),
236 cache,
237 }
238 }
239
240 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
242 &self.input
243 }
244
245 pub fn fetch(&self) -> usize {
247 self.fetch
248 }
249
250 fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
252 PlanProperties::new(
253 input.equivalence_properties().clone(), input.output_partitioning().clone(), input.pipeline_behavior(),
256 Boundedness::Bounded,
258 )
259 }
260}
261
262impl DisplayAs for LocalLimitExec {
263 fn fmt_as(
264 &self,
265 t: DisplayFormatType,
266 f: &mut std::fmt::Formatter,
267 ) -> std::fmt::Result {
268 match t {
269 DisplayFormatType::Default | DisplayFormatType::Verbose => {
270 write!(f, "LocalLimitExec: fetch={}", self.fetch)
271 }
272 DisplayFormatType::TreeRender => {
273 write!(f, "limit={}", self.fetch)
274 }
275 }
276 }
277}
278
279impl ExecutionPlan for LocalLimitExec {
280 fn name(&self) -> &'static str {
281 "LocalLimitExec"
282 }
283
284 fn as_any(&self) -> &dyn Any {
286 self
287 }
288
289 fn properties(&self) -> &PlanProperties {
290 &self.cache
291 }
292
293 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
294 vec![&self.input]
295 }
296
297 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
298 vec![false]
299 }
300
301 fn maintains_input_order(&self) -> Vec<bool> {
302 vec![true]
303 }
304
305 fn with_new_children(
306 self: Arc<Self>,
307 children: Vec<Arc<dyn ExecutionPlan>>,
308 ) -> Result<Arc<dyn ExecutionPlan>> {
309 match children.len() {
310 1 => Ok(Arc::new(LocalLimitExec::new(
311 Arc::clone(&children[0]),
312 self.fetch,
313 ))),
314 _ => internal_err!("LocalLimitExec wrong number of children"),
315 }
316 }
317
318 fn execute(
319 &self,
320 partition: usize,
321 context: Arc<TaskContext>,
322 ) -> Result<SendableRecordBatchStream> {
323 trace!(
324 "Start LocalLimitExec::execute for partition {} of context session_id {} and task_id {:?}",
325 partition,
326 context.session_id(),
327 context.task_id()
328 );
329 let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
330 let stream = self.input.execute(partition, context)?;
331 Ok(Box::pin(LimitStream::new(
332 stream,
333 0,
334 Some(self.fetch),
335 baseline_metrics,
336 )))
337 }
338
339 fn metrics(&self) -> Option<MetricsSet> {
340 Some(self.metrics.clone_inner())
341 }
342
343 fn statistics(&self) -> Result<Statistics> {
344 self.partition_statistics(None)
345 }
346
347 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
348 self.input
349 .partition_statistics(partition)?
350 .with_fetch(Some(self.fetch), 0, 1)
351 }
352
353 fn fetch(&self) -> Option<usize> {
354 Some(self.fetch)
355 }
356
357 fn supports_limit_pushdown(&self) -> bool {
358 true
359 }
360
361 fn cardinality_effect(&self) -> CardinalityEffect {
362 CardinalityEffect::LowerEqual
363 }
364}
365
/// A Limit stream skips `skip` rows and then fetches up to `fetch` rows
/// from its input stream.
pub struct LimitStream {
    /// Remaining number of rows to skip
    skip: usize,
    /// Remaining number of rows to emit; `usize::MAX` means "no limit"
    /// (see `LimitStream::new`)
    fetch: usize,
    /// The input stream; set to `None` once the limit is reached so the
    /// upstream operators can shut down early
    input: Option<SendableRecordBatchStream>,
    /// Copy of the input schema, kept so `schema()` still works after
    /// `input` has been dropped
    schema: SchemaRef,
    /// Execution-time metrics (output rows, elapsed compute)
    baseline_metrics: BaselineMetrics,
}
380
381impl LimitStream {
382 pub fn new(
383 input: SendableRecordBatchStream,
384 skip: usize,
385 fetch: Option<usize>,
386 baseline_metrics: BaselineMetrics,
387 ) -> Self {
388 let schema = input.schema();
389 Self {
390 skip,
391 fetch: fetch.unwrap_or(usize::MAX),
392 input: Some(input),
393 schema,
394 baseline_metrics,
395 }
396 }
397
398 fn poll_and_skip(
399 &mut self,
400 cx: &mut Context<'_>,
401 ) -> Poll<Option<Result<RecordBatch>>> {
402 let input = self.input.as_mut().unwrap();
403 loop {
404 let poll = input.poll_next_unpin(cx);
405 let poll = poll.map_ok(|batch| {
406 if batch.num_rows() <= self.skip {
407 self.skip -= batch.num_rows();
408 RecordBatch::new_empty(input.schema())
409 } else {
410 let new_batch = batch.slice(self.skip, batch.num_rows() - self.skip);
411 self.skip = 0;
412 new_batch
413 }
414 });
415
416 match &poll {
417 Poll::Ready(Some(Ok(batch))) => {
418 if batch.num_rows() > 0 {
419 break poll;
420 } else {
421 }
423 }
424 Poll::Ready(Some(Err(_e))) => break poll,
425 Poll::Ready(None) => break poll,
426 Poll::Pending => break poll,
427 }
428 }
429 }
430
431 fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
433 let _timer = self.baseline_metrics.elapsed_compute().timer();
435 if self.fetch == 0 {
436 self.input = None; None
438 } else if batch.num_rows() < self.fetch {
439 self.fetch -= batch.num_rows();
441 Some(batch)
442 } else if batch.num_rows() >= self.fetch {
443 let batch_rows = self.fetch;
444 self.fetch = 0;
445 self.input = None; Some(batch.slice(0, batch_rows))
449 } else {
450 unreachable!()
451 }
452 }
453}
454
impl Stream for LimitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Once `skip` reaches 0 we are in the "fetch" phase and batches can
        // be forwarded directly (after applying the fetch budget below).
        let fetch_started = self.skip == 0;
        let poll = match &mut self.input {
            Some(input) => {
                let poll = if fetch_started {
                    input.poll_next_unpin(cx)
                } else {
                    // Still skipping the leading rows
                    self.poll_and_skip(cx)
                };

                // Apply the fetch limit; `stream_limit` returns `None` once
                // the limit is exhausted, which `transpose` converts into
                // end-of-stream.
                poll.map(|x| match x {
                    Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
                    other => other,
                })
            }
            // The input was dropped after the limit was reached
            None => Poll::Ready(None),
        };

        // Track output rows and poll timing in the baseline metrics
        self.baseline_metrics.record_poll(poll)
    }
}
483
impl RecordBatchStream for LimitStream {
    /// Limiting does not change the schema of the input stream.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
490
#[cfg(test)]
mod tests {
    use super::*;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::common::collect;
    use crate::test;

    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use arrow::array::RecordBatchOptions;
    use arrow::datatypes::Schema;
    use datafusion_common::stats::Precision;
    use datafusion_physical_expr::PhysicalExpr;
    use datafusion_physical_expr::expressions::col;

    #[tokio::test]
    async fn limit() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        // input should have 4 partitions
        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let limit =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7));

        // collect the limited output
        let iter = limit.execute(0, task_ctx)?;
        let batches = collect(iter).await?;

        // there should be a total of exactly `fetch` = 7 rows
        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 7);

        Ok(())
    }

    #[tokio::test]
    async fn limit_early_shutdown() -> Result<()> {
        let batches = vec![
            test::make_partition(5),
            test::make_partition(10),
            test::make_partition(15),
            test::make_partition(20),
            test::make_partition(25),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // limit of six needs to consume the entire first record batch
        // (5 rows) and 1 row from the second (1 row)
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // only six rows should have been produced
        assert_eq!(num_rows, 6);

        // the stream should only have consumed the first two batches
        assert_eq!(index.value(), 2);

        Ok(())
    }

    #[tokio::test]
    async fn limit_equals_batch_size() -> Result<()> {
        let batches = vec![
            test::make_partition(6),
            test::make_partition(6),
            test::make_partition(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // limit of six needs to consume exactly the first record batch
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // only six rows should have been produced
        assert_eq!(num_rows, 6);

        // the stream should not have consumed any batch beyond the first
        assert_eq!(index.value(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn limit_no_column() -> Result<()> {
        // batches with rows but zero columns still count toward the limit
        let batches = vec![
            make_batch_no_column(6),
            make_batch_no_column(6),
            make_batch_no_column(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // limit of six needs to consume exactly the first record batch
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // only six rows should have been produced
        assert_eq!(num_rows, 6);

        // the stream should not have consumed any batch beyond the first
        assert_eq!(index.value(), 1);

        Ok(())
    }

    // Run a GlobalLimitExec with the given skip/fetch over 4 partitions
    // (400 rows total) and return the number of rows produced.
    async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> {
        let task_ctx = Arc::new(TaskContext::default());

        // 4 partitions of 100 rows each (see `scan_partitioned`) — TODO confirm
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        // the result should contain the rows after skip, capped at fetch
        let iter = offset.execute(0, task_ctx)?;
        let batches = collect(iter).await?;
        Ok(batches.iter().map(|batch| batch.num_rows()).sum())
    }

    #[tokio::test]
    async fn skip_none_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(0, None).await?;
        assert_eq!(row_count, 400);
        Ok(())
    }

    #[tokio::test]
    async fn skip_none_fetch_50() -> Result<()> {
        let row_count = skip_and_fetch(0, Some(50)).await?;
        assert_eq!(row_count, 50);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_none() -> Result<()> {
        // there are total of 400 rows, we skipped 3 rows (offset = 3)
        let row_count = skip_and_fetch(3, None).await?;
        assert_eq!(row_count, 397);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_10_stats() -> Result<()> {
        // there are total of 400 rows, we skipped 3 rows (offset = 3)
        let row_count = skip_and_fetch(3, Some(10)).await?;
        assert_eq!(row_count, 10);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(400, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_1() -> Result<()> {
        // there are a total of 400 rows
        let row_count = skip_and_fetch(400, Some(1)).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_401_fetch_none() -> Result<()> {
        // there are total of 400 rows, we skipped 401 rows (offset = 401)
        let row_count = skip_and_fetch(401, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_global_limit() -> Result<()> {
        let row_count = row_number_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Exact(1));

        let row_count = row_number_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(400));

        let row_count =
            row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        // with an aggregation below the limit, statistics become inexact
        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Inexact(1));

        let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(400));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_local_limit() -> Result<()> {
        let row_count = row_number_statistics_for_local_limit(4, 10).await?;
        assert_eq!(row_count, Precision::Exact(10));

        Ok(())
    }

    // Build a GlobalLimitExec over exact input statistics and return its
    // estimated row count.
    async fn row_number_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    // Build a single-group PhysicalGroupBy over the named columns.
    pub fn build_group_by(
        input_schema: &SchemaRef,
        columns: Vec<String>,
    ) -> PhysicalGroupBy {
        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
        for column in columns.iter() {
            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
        }
        PhysicalGroupBy::new_single(group_by_expr.clone())
    }

    // Same as `row_number_statistics_for_global_limit` but with an aggregate
    // between the scan and the limit, which makes input statistics inexact.
    async fn row_number_inexact_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        // aggregate on the input, which makes the row count inexact
        let agg = AggregateExec::try_new(
            AggregateMode::Final,
            build_group_by(&csv.schema(), vec!["i".to_string()]),
            vec![],
            vec![],
            Arc::clone(&csv),
            Arc::clone(&csv.schema()),
        )?;
        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);

        let offset = GlobalLimitExec::new(
            Arc::new(CoalescePartitionsExec::new(agg_exec)),
            skip,
            fetch,
        );

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    // Estimated row count of a LocalLimitExec directly over the scan.
    async fn row_number_statistics_for_local_limit(
        num_partitions: usize,
        fetch: usize,
    ) -> Result<Precision<usize>> {
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset = LocalLimitExec::new(csv, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    /// Return a RecordBatch with a single array with `sz` rows but no columns
    fn make_batch_no_column(sz: usize) -> RecordBatch {
        let schema = Arc::new(Schema::empty());

        let options = RecordBatchOptions::new().with_row_count(Option::from(sz));
        RecordBatch::try_new_with_options(schema, vec![], &options).unwrap()
    }
}
838}