use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream, Statistics,
};
use crate::execution_plan::{Boundedness, CardinalityEffect};
use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning};

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{internal_err, Result};
use datafusion_execution::TaskContext;

use futures::stream::{Stream, StreamExt};
use log::trace;

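/// Limit execution plan: skips the first `skip` rows and then returns at
/// most `fetch` rows from its single input partition.
///
/// Combined with `CoalescePartitionsExec`, this operator implements
/// `LIMIT .. OFFSET ..` over a multi-partition input. A minimal sketch of
/// that wiring (it mirrors the `limit` test at the bottom of this file;
/// `input` and `task_ctx` are assumed to be in scope):
///
/// ```ignore
/// // SELECT * FROM input OFFSET 3 LIMIT 7
/// let plan = GlobalLimitExec::new(
///     Arc::new(CoalescePartitionsExec::new(input)),
///     3,       // skip: offset of 3 rows
///     Some(7), // fetch: return at most 7 rows
/// );
/// let stream = plan.execute(0, task_ctx)?;
/// ```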
#[derive(Debug, Clone)]
pub struct GlobalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Number of rows to skip before fetching
    skip: usize,
    /// Maximum number of rows to fetch; `None` means fetch all rows
    fetch: Option<usize>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalences, partitioning, boundedness)
    cache: PlanProperties,
}

impl GlobalLimitExec {
    /// Create a new `GlobalLimitExec` with the given `skip` and `fetch`
    pub fn new(input: Arc<dyn ExecutionPlan>, skip: usize, fetch: Option<usize>) -> Self {
        let cache = Self::compute_properties(&input);
        GlobalLimitExec {
            input,
            skip,
            fetch,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        }
    }

    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Number of rows to skip before fetching
    pub fn skip(&self) -> usize {
        self.skip
    }

    /// Maximum number of rows to fetch; `None` means fetch all rows
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Compute the cached plan properties
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        PlanProperties::new(
            input.equivalence_properties().clone(),
            // A global limit always produces a single output partition
            Partitioning::UnknownPartitioning(1),
            input.pipeline_behavior(),
            // Limits are bounded: they emit at most `fetch` rows
            Boundedness::Bounded,
        )
    }
}

impl DisplayAs for GlobalLimitExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "GlobalLimitExec: skip={}, fetch={}",
                    self.skip,
                    self.fetch.map_or("None".to_string(), |x| x.to_string())
                )
            }
            DisplayFormatType::TreeRender => {
                if let Some(fetch) = self.fetch {
                    writeln!(f, "limit={fetch}")?;
                }
                write!(f, "skip={}", self.skip)
            }
        }
    }
}

impl ExecutionPlan for GlobalLimitExec {
    fn name(&self) -> &'static str {
        "GlobalLimitExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(GlobalLimitExec::new(
            Arc::clone(&children[0]),
            self.skip,
            self.fetch,
        )))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!("Start GlobalLimitExec::execute for partition: {partition}");
        // GlobalLimitExec has a single output partition
        if 0 != partition {
            return internal_err!("GlobalLimitExec invalid partition {partition}");
        }

        // GlobalLimitExec requires a single input partition
        if 1 != self.input.output_partitioning().partition_count() {
            return internal_err!("GlobalLimitExec requires a single input partition");
        }

        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let stream = self.input.execute(0, context)?;
        Ok(Box::pin(LimitStream::new(
            stream,
            self.skip,
            self.fetch,
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input.partition_statistics(partition)?.with_fetch(
            self.schema(),
            self.fetch,
            self.skip,
            1,
        )
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}

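/// Limit execution plan that applies `fetch` to each partition of its input
/// independently, so its total output may contain up to
/// `fetch * partition_count` rows.
///
/// A short sketch of the usual pairing with `GlobalLimitExec`, which
/// enforces the final limit once partitions are merged (`input` is assumed
/// to be a multi-partition plan):
///
/// ```ignore
/// // Pre-prune each partition to at most 10 rows ...
/// let local = Arc::new(LocalLimitExec::new(input, 10));
/// // ... then take the first 10 rows of the merged result
/// let merged = Arc::new(CoalescePartitionsExec::new(local));
/// let global = GlobalLimitExec::new(merged, 0, Some(10));
/// ```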
#[derive(Debug)]
pub struct LocalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Maximum number of rows to return per partition
    fetch: usize,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalences, partitioning, boundedness)
    cache: PlanProperties,
}

impl LocalLimitExec {
    /// Create a new `LocalLimitExec` with the given per-partition `fetch`
    pub fn new(input: Arc<dyn ExecutionPlan>, fetch: usize) -> Self {
        let cache = Self::compute_properties(&input);
        Self {
            input,
            fetch,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        }
    }

    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Maximum number of rows to return per partition
    pub fn fetch(&self) -> usize {
        self.fetch
    }

    /// Compute the cached plan properties
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        PlanProperties::new(
            input.equivalence_properties().clone(),
            // A local limit preserves the partitioning of its input
            input.output_partitioning().clone(),
            input.pipeline_behavior(),
            // Limits are bounded: they emit at most `fetch` rows per partition
            Boundedness::Bounded,
        )
    }
}

impl DisplayAs for LocalLimitExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "LocalLimitExec: fetch={}", self.fetch)
            }
            DisplayFormatType::TreeRender => {
                write!(f, "limit={}", self.fetch)
            }
        }
    }
}

impl ExecutionPlan for LocalLimitExec {
    fn name(&self) -> &'static str {
        "LocalLimitExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.len() {
            1 => Ok(Arc::new(LocalLimitExec::new(
                Arc::clone(&children[0]),
                self.fetch,
            ))),
            _ => internal_err!("LocalLimitExec wrong number of children"),
        }
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start LocalLimitExec::execute for partition {} of context session_id {} and task_id {:?}",
            partition,
            context.session_id(),
            context.task_id()
        );
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        // Apply the limit to this partition only: skip nothing, fetch up to `fetch` rows
        let stream = self.input.execute(partition, context)?;
        Ok(Box::pin(LimitStream::new(
            stream,
            0,
            Some(self.fetch),
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input.partition_statistics(partition)?.with_fetch(
            self.schema(),
            Some(self.fetch),
            0,
            1,
        )
    }

    fn fetch(&self) -> Option<usize> {
        Some(self.fetch)
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        // A limit never increases the number of rows
        CardinalityEffect::LowerEqual
    }
}

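/// A stream that skips `skip` rows and then emits at most `fetch` rows,
/// wrapping any `SendableRecordBatchStream`.
///
/// A minimal sketch of standalone use, mirroring the unit tests below
/// (`input` is assumed to be a record batch stream and `collect` is the
/// helper from `crate::common`):
///
/// ```ignore
/// let metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
/// // Skip no rows, then emit at most 6 rows
/// let limited = LimitStream::new(Box::pin(input), 0, Some(6), metrics);
/// let batches = collect(Box::pin(limited)).await?;
/// ```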
pub struct LimitStream {
    /// The number of rows that are still to be skipped
    skip: usize,
    /// The number of rows that may still be fetched; `usize::MAX` means "no limit"
    fetch: usize,
    /// The input stream; set to `None` once the limit is reached to enable
    /// early shutdown of the operators feeding this stream
    input: Option<SendableRecordBatchStream>,
    /// Copy of the input schema
    schema: SchemaRef,
    /// Execution time metrics
    baseline_metrics: BaselineMetrics,
}

impl LimitStream {
    /// Create a new `LimitStream`; a `fetch` of `None` means fetch all rows
    pub fn new(
        input: SendableRecordBatchStream,
        skip: usize,
        fetch: Option<usize>,
        baseline_metrics: BaselineMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            skip,
            fetch: fetch.unwrap_or(usize::MAX),
            input: Some(input),
            schema,
            baseline_metrics,
        }
    }

    /// Poll the input, discarding rows until `self.skip` reaches zero, and
    /// return the first batch that has rows remaining after the skip
    fn poll_and_skip(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // `input` is only `None` after the limit is reached, and this method
        // is only called while rows are still being skipped
        let input = self.input.as_mut().unwrap();
        loop {
            let poll = input.poll_next_unpin(cx);
            let poll = poll.map_ok(|batch| {
                if batch.num_rows() <= self.skip {
                    // The entire batch is skipped
                    self.skip -= batch.num_rows();
                    RecordBatch::new_empty(input.schema())
                } else {
                    // Skip the remaining rows and keep the rest of the batch
                    let new_batch = batch.slice(self.skip, batch.num_rows() - self.skip);
                    self.skip = 0;
                    new_batch
                }
            });

            match &poll {
                Poll::Ready(Some(Ok(batch))) => {
                    if batch.num_rows() > 0 {
                        break poll;
                    } else {
                        // This batch was skipped entirely; keep polling
                    }
                }
                Poll::Ready(Some(Err(_e))) => break poll,
                Poll::Ready(None) => break poll,
                Poll::Pending => break poll,
            }
        }
    }

    /// Limit the batch to the rows that still fit within `self.fetch`,
    /// clearing the input once the limit has been reached
    fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
        // Records time on drop
        let _timer = self.baseline_metrics.elapsed_compute().timer();
        if self.fetch == 0 {
            // The limit is already reached; drop the input so upstream
            // operators can shut down early
            self.input = None;
            None
        } else if batch.num_rows() < self.fetch {
            self.fetch -= batch.num_rows();
            Some(batch)
        } else if batch.num_rows() >= self.fetch {
            let batch_rows = self.fetch;
            self.fetch = 0;
            self.input = None;

            // `batch_rows` is guaranteed to be <= `batch.num_rows()`
            Some(batch.slice(0, batch_rows))
        } else {
            unreachable!()
        }
    }
}

impl Stream for LimitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Once all `skip` rows have been discarded, poll the input directly
        let fetch_started = self.skip == 0;
        let poll = match &mut self.input {
            Some(input) => {
                let poll = if fetch_started {
                    input.poll_next_unpin(cx)
                } else {
                    self.poll_and_skip(cx)
                };

                poll.map(|x| match x {
                    Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
                    other => other,
                })
            }
            // The input has been cleared: the limit was already reached
            None => Poll::Ready(None),
        };

        self.baseline_metrics.record_poll(poll)
    }
}

impl RecordBatchStream for LimitStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::common::collect;
    use crate::test;

    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use arrow::array::RecordBatchOptions;
    use arrow::datatypes::Schema;
    use datafusion_common::stats::Precision;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr::PhysicalExpr;

    #[tokio::test]
    async fn limit() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        // Input should have 4 partitions
        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let limit =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7));

        // The result should contain only the first 7 rows
        let iter = limit.execute(0, task_ctx)?;
        let batches = collect(iter).await?;

        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 7);

        Ok(())
    }

    #[tokio::test]
    async fn limit_early_shutdown() -> Result<()> {
        let batches = vec![
            test::make_partition(5),
            test::make_partition(10),
            test::make_partition(15),
            test::make_partition(20),
            test::make_partition(25),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // A limit of six needs to consume the entire first batch (5 rows)
        // and one row from the second batch (10 rows)
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // Only six rows should have been produced
        assert_eq!(num_rows, 6);

        // Only the first two batches should have been consumed
        assert_eq!(index.value(), 2);

        Ok(())
    }

    #[tokio::test]
    async fn limit_equals_batch_size() -> Result<()> {
        let batches = vec![
            test::make_partition(6),
            test::make_partition(6),
            test::make_partition(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // A limit of six needs to consume exactly the first batch (6 rows)
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // Only six rows should have been produced
        assert_eq!(num_rows, 6);

        // Only the first batch should have been consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn limit_no_column() -> Result<()> {
        let batches = vec![
            make_batch_no_column(6),
            make_batch_no_column(6),
            make_batch_no_column(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // A limit of six needs to consume exactly the first batch (6 rows)
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // Only six rows should have been produced
        assert_eq!(num_rows, 6);

        // Only the first batch should have been consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    // Helper for the "skip" tests below: run a GlobalLimitExec over 4
    // partitions of 100 rows each (400 rows total) and return the number of
    // rows that come back
    async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> {
        let task_ctx = Arc::new(TaskContext::default());

        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        let iter = offset.execute(0, task_ctx)?;
        let batches = collect(iter).await?;
        Ok(batches.iter().map(|batch| batch.num_rows()).sum())
    }

    #[tokio::test]
    async fn skip_none_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(0, None).await?;
        assert_eq!(row_count, 400);
        Ok(())
    }

    #[tokio::test]
    async fn skip_none_fetch_50() -> Result<()> {
        let row_count = skip_and_fetch(0, Some(50)).await?;
        assert_eq!(row_count, 50);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_none() -> Result<()> {
        // There are a total of 400 rows; skipping 3 leaves 397
        let row_count = skip_and_fetch(3, None).await?;
        assert_eq!(row_count, 397);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_10_stats() -> Result<()> {
        let row_count = skip_and_fetch(3, Some(10)).await?;
        assert_eq!(row_count, 10);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_none() -> Result<()> {
        // Skipping all 400 rows leaves nothing
        let row_count = skip_and_fetch(400, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_1() -> Result<()> {
        // Skipping all 400 rows leaves nothing to fetch
        let row_count = skip_and_fetch(400, Some(1)).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_401_fetch_none() -> Result<()> {
        // Skipping more rows than the input contains yields no rows
        let row_count = skip_and_fetch(401, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_global_limit() -> Result<()> {
        let row_count = row_number_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Exact(1));

        let row_count = row_number_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(400));

        let row_count =
            row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Inexact(1));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(400));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_local_limit() -> Result<()> {
        let row_count = row_number_statistics_for_local_limit(4, 10).await?;
        assert_eq!(row_count, Precision::Exact(10));

        Ok(())
    }

    async fn row_number_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    pub fn build_group_by(
        input_schema: &SchemaRef,
        columns: Vec<String>,
    ) -> PhysicalGroupBy {
        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
        for column in columns.iter() {
            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
        }
        PhysicalGroupBy::new_single(group_by_expr.clone())
    }

    async fn row_number_inexact_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        // Adding a "GROUP BY i" changes the input row-count statistics from
        // Exact to Inexact
        let agg = AggregateExec::try_new(
            AggregateMode::Final,
            build_group_by(&csv.schema(), vec!["i".to_string()]),
            vec![],
            vec![],
            Arc::clone(&csv),
            Arc::clone(&csv.schema()),
        )?;
        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);

        let offset = GlobalLimitExec::new(
            Arc::new(CoalescePartitionsExec::new(agg_exec)),
            skip,
            fetch,
        );

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    async fn row_number_statistics_for_local_limit(
        num_partitions: usize,
        fetch: usize,
    ) -> Result<Precision<usize>> {
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset = LocalLimitExec::new(csv, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    /// Create a batch with `sz` rows but no columns, using an explicit row count
    fn make_batch_no_column(sz: usize) -> RecordBatch {
        let schema = Arc::new(Schema::empty());

        let options = RecordBatchOptions::new().with_row_count(Option::from(sz));
        RecordBatch::try_new_with_options(schema, vec![], &options).unwrap()
    }
}