use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{
    DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
    SendableRecordBatchStream, Statistics,
};
use crate::execution_plan::{Boundedness, CardinalityEffect};
use crate::{DisplayFormatType, Distribution, ExecutionPlan, Partitioning};

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{internal_err, Result};
use datafusion_execution::TaskContext;

use futures::stream::{Stream, StreamExt};
use log::trace;

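/// Limit execution plan: skips `skip` rows from its single input partition,
/// then returns at most `fetch` rows.
///
/// A minimal usage sketch (not from the original file; `child` is a
/// hypothetical single-partition `Arc<dyn ExecutionPlan>`):
///
/// ```ignore
/// // Skip the first 10 rows, then return at most 5 rows.
/// let limit = GlobalLimitExec::new(child, 10, Some(5));
/// ```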
#[derive(Debug, Clone)]
pub struct GlobalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Number of rows to skip before fetching
    skip: usize,
    /// Maximum number of rows to fetch; `None` means fetching all rows
    fetch: Option<usize>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalence properties, partitioning, etc.)
    cache: PlanProperties,
}

impl GlobalLimitExec {
    /// Create a new `GlobalLimitExec`
    pub fn new(input: Arc<dyn ExecutionPlan>, skip: usize, fetch: Option<usize>) -> Self {
        let cache = Self::compute_properties(&input);
        GlobalLimitExec {
            input,
            skip,
            fetch,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        }
    }

    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Number of rows to skip before fetching
    pub fn skip(&self) -> usize {
        self.skip
    }

    /// Maximum number of rows to fetch; `None` means fetching all rows
    pub fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    /// Compute the cached plan properties
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        PlanProperties::new(
            input.equivalence_properties().clone(),
            // A global limit always produces a single output partition
            Partitioning::UnknownPartitioning(1),
            input.pipeline_behavior(),
            Boundedness::Bounded,
        )
    }
}

impl DisplayAs for GlobalLimitExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(
                    f,
                    "GlobalLimitExec: skip={}, fetch={}",
                    self.skip,
                    self.fetch
                        .map_or_else(|| "None".to_string(), |x| x.to_string())
                )
            }
            DisplayFormatType::TreeRender => {
                if let Some(fetch) = self.fetch {
                    writeln!(f, "limit={fetch}")?;
                }
                write!(f, "skip={}", self.skip)
            }
        }
    }
}

impl ExecutionPlan for GlobalLimitExec {
    fn name(&self) -> &'static str {
        "GlobalLimitExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn required_input_distribution(&self) -> Vec<Distribution> {
        vec![Distribution::SinglePartition]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(GlobalLimitExec::new(
            Arc::clone(&children[0]),
            self.skip,
            self.fetch,
        )))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!("Start GlobalLimitExec::execute for partition: {partition}");
        // GlobalLimitExec has a single output partition
        if partition != 0 {
            return internal_err!("GlobalLimitExec invalid partition {partition}");
        }

        // GlobalLimitExec requires a single input partition
        if self.input.output_partitioning().partition_count() != 1 {
            return internal_err!("GlobalLimitExec requires a single input partition");
        }

        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let stream = self.input.execute(0, context)?;
        Ok(Box::pin(LimitStream::new(
            stream,
            self.skip,
            self.fetch,
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input
            .partition_statistics(partition)?
            .with_fetch(self.fetch, self.skip, 1)
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }
}

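/// Local limit execution plan: applies `fetch` to each of its input
/// partitions independently (in contrast to `GlobalLimitExec`, which limits
/// the combined output).
///
/// A minimal usage sketch (not from the original file; `child` is a
/// hypothetical multi-partition `Arc<dyn ExecutionPlan>`):
///
/// ```ignore
/// // Each partition returns at most 5 rows.
/// let limit = LocalLimitExec::new(child, 5);
/// ```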
#[derive(Debug)]
pub struct LocalLimitExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Maximum number of rows to return per partition
    fetch: usize,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalence properties, partitioning, etc.)
    cache: PlanProperties,
}

impl LocalLimitExec {
    /// Create a new `LocalLimitExec`
    pub fn new(input: Arc<dyn ExecutionPlan>, fetch: usize) -> Self {
        let cache = Self::compute_properties(&input);
        Self {
            input,
            fetch,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
        }
    }

    /// Input execution plan
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

    /// Maximum number of rows to return per partition
    pub fn fetch(&self) -> usize {
        self.fetch
    }

    /// Compute the cached plan properties
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        PlanProperties::new(
            input.equivalence_properties().clone(),
            // The input partitioning is preserved
            input.output_partitioning().clone(),
            input.pipeline_behavior(),
            Boundedness::Bounded,
        )
    }
}

impl DisplayAs for LocalLimitExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => {
                write!(f, "LocalLimitExec: fetch={}", self.fetch)
            }
            DisplayFormatType::TreeRender => {
                write!(f, "limit={}", self.fetch)
            }
        }
    }
}

impl ExecutionPlan for LocalLimitExec {
    fn name(&self) -> &'static str {
        "LocalLimitExec"
    }

    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn maintains_input_order(&self) -> Vec<bool> {
        vec![true]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        match children.len() {
            1 => Ok(Arc::new(LocalLimitExec::new(
                Arc::clone(&children[0]),
                self.fetch,
            ))),
            _ => internal_err!("LocalLimitExec wrong number of children"),
        }
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        trace!(
            "Start LocalLimitExec::execute for partition {} of context session_id {} and task_id {:?}",
            partition,
            context.session_id(),
            context.task_id()
        );
        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
        let stream = self.input.execute(partition, context)?;
        Ok(Box::pin(LimitStream::new(
            stream,
            0,
            Some(self.fetch),
            baseline_metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
        self.input
            .partition_statistics(partition)?
            .with_fetch(Some(self.fetch), 0, 1)
    }

    fn fetch(&self) -> Option<usize> {
        Some(self.fetch)
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        // a limit never increases the number of rows
        CardinalityEffect::LowerEqual
    }
}

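/// A stream that skips `skip` rows and then returns at most `fetch` rows from
/// its input. Once the limit is reached it drops the input stream so upstream
/// operators can shut down early.
///
/// A minimal sketch of wrapping an input (not from the original file;
/// `input_stream` and `metrics` are hypothetical):
///
/// ```ignore
/// // Skip 3 rows, then emit at most 10 rows from `input_stream`.
/// let limited = LimitStream::new(input_stream, 3, Some(10), metrics);
/// ```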
pub struct LimitStream {
    /// The remaining number of rows to skip
    skip: usize,
    /// The remaining number of rows to produce
    fetch: usize,
    /// The input stream; set to `None` once the limit is reached to allow
    /// early shutdown of the input
    input: Option<SendableRecordBatchStream>,
    /// Copy of the input schema
    schema: SchemaRef,
    /// Execution time metrics
    baseline_metrics: BaselineMetrics,
}

impl LimitStream {
    pub fn new(
        input: SendableRecordBatchStream,
        skip: usize,
        fetch: Option<usize>,
        baseline_metrics: BaselineMetrics,
    ) -> Self {
        let schema = input.schema();
        Self {
            skip,
            fetch: fetch.unwrap_or(usize::MAX),
            input: Some(input),
            schema,
            baseline_metrics,
        }
    }

    /// Polls the input, discarding rows until `skip` rows have been skipped
    fn poll_and_skip(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        // `input` is only set to `None` after the fetch limit is reached,
        // so it is always `Some` while skipping
        let input = self.input.as_mut().unwrap();
        loop {
            let poll = input.poll_next_unpin(cx);
            let poll = poll.map_ok(|batch| {
                if batch.num_rows() <= self.skip {
                    // skip the entire batch
                    self.skip -= batch.num_rows();
                    RecordBatch::new_empty(input.schema())
                } else {
                    // skip the remaining rows and keep the rest of the batch
                    let new_batch = batch.slice(self.skip, batch.num_rows() - self.skip);
                    self.skip = 0;
                    new_batch
                }
            });

            match &poll {
                Poll::Ready(Some(Ok(batch))) => {
                    if batch.num_rows() > 0 {
                        break poll;
                    }
                    // empty batch: the whole batch was skipped, keep polling
                }
                Poll::Ready(Some(Err(_e))) => break poll,
                Poll::Ready(None) => break poll,
                Poll::Pending => break poll,
            }
        }
    }

    /// Limits `batch` to the remaining `fetch` count, returning `None` once
    /// the limit has been reached
    fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
        // records time on drop
        let _timer = self.baseline_metrics.elapsed_compute().timer();
        if self.fetch == 0 {
            // limit already reached, drop the input for early shutdown
            self.input = None;
            None
        } else if batch.num_rows() < self.fetch {
            // the whole batch fits within the remaining limit
            self.fetch -= batch.num_rows();
            Some(batch)
        } else {
            // this batch reaches the limit: take the prefix and drop the input
            let batch_rows = self.fetch;
            self.fetch = 0;
            self.input = None;
            Some(batch.slice(0, batch_rows))
        }
    }
}

impl Stream for LimitStream {
    type Item = Result<RecordBatch>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let fetch_started = self.skip == 0;
        let poll = match &mut self.input {
            Some(input) => {
                let poll = if fetch_started {
                    input.poll_next_unpin(cx)
                } else {
                    self.poll_and_skip(cx)
                };

                poll.map(|x| match x {
                    Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
                    other => other,
                })
            }
            // the input has been fully consumed or the limit was reached
            None => Poll::Ready(None),
        };

        self.baseline_metrics.record_poll(poll)
    }
}

impl RecordBatchStream for LimitStream {
    /// Get the schema
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::coalesce_partitions::CoalescePartitionsExec;
    use crate::common::collect;
    use crate::test;

    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
    use arrow::array::RecordBatchOptions;
    use arrow::datatypes::Schema;
    use datafusion_common::stats::Precision;
    use datafusion_physical_expr::expressions::col;
    use datafusion_physical_expr::PhysicalExpr;

    #[tokio::test]
    async fn limit() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        // input should have 4 partitions
        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let limit =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), 0, Some(7));

        // execute the single output partition
        let iter = limit.execute(0, task_ctx)?;
        let batches = collect(iter).await?;

        // there should be a total of 7 rows
        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 7);

        Ok(())
    }
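
    // Illustrative companion test, not from the original suite: LocalLimitExec
    // applies its fetch to each partition independently, so executing a single
    // partition should yield at most `fetch` rows. This assumes, as the tests
    // above suggest, that each of the 4 scanned partitions holds at least 7 rows.
    #[tokio::test]
    async fn local_limit_single_partition() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let csv = test::scan_partitioned(4);
        let limit = LocalLimitExec::new(csv, 7);

        // execute only partition 0
        let iter = limit.execute(0, task_ctx)?;
        let batches = collect(iter).await?;

        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 7);

        Ok(())
    }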

    #[tokio::test]
    async fn limit_early_shutdown() -> Result<()> {
        let batches = vec![
            test::make_partition(5),
            test::make_partition(10),
            test::make_partition(15),
            test::make_partition(20),
            test::make_partition(25),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // a limit of six needs to consume the entire first record batch (5 rows)
        // and 1 row from the second (10 rows)
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // only the first two batches should have been consumed
        assert_eq!(index.value(), 2);

        Ok(())
    }
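
    // Illustrative sketch, not from the original suite: the `skip` parameter of
    // LimitStream discards rows before `fetch` starts counting. With batches of
    // 5 and 10 rows, skip=8 consumes the whole first batch and 3 rows of the
    // second, so fetch=4 returns the following 4 rows.
    #[tokio::test]
    async fn limit_stream_skip_then_fetch() -> Result<()> {
        let batches = vec![test::make_partition(5), test::make_partition(10)];
        let input = test::exec::TestStream::new(batches);

        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 8, Some(4), baseline_metrics);

        let results = collect(Box::pin(limit_stream)).await?;
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        assert_eq!(num_rows, 4);

        Ok(())
    }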

    #[tokio::test]
    async fn limit_equals_batch_size() -> Result<()> {
        let batches = vec![
            test::make_partition(6),
            test::make_partition(6),
            test::make_partition(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // a limit of six needs to consume exactly the entire first record batch
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // only the first batch should have been consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn limit_no_column() -> Result<()> {
        let batches = vec![
            make_batch_no_column(6),
            make_batch_no_column(6),
            make_batch_no_column(6),
        ];
        let input = test::exec::TestStream::new(batches);

        let index = input.index();
        assert_eq!(index.value(), 0);

        // a limit of six needs to consume exactly the entire first record batch,
        // even though the batches have no columns
        let baseline_metrics = BaselineMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
        let limit_stream =
            LimitStream::new(Box::pin(input), 0, Some(6), baseline_metrics);
        assert_eq!(index.value(), 0);

        let results = collect(Box::pin(limit_stream)).await.unwrap();
        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
        // only 6 rows should have been produced
        assert_eq!(num_rows, 6);

        // only the first batch should have been consumed
        assert_eq!(index.value(), 1);

        Ok(())
    }

    // Applies GlobalLimitExec(skip, fetch) to 4 scanned partitions
    // (400 rows in total) and returns the resulting row count
    async fn skip_and_fetch(skip: usize, fetch: Option<usize>) -> Result<usize> {
        let task_ctx = Arc::new(TaskContext::default());

        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        // execute the plan and sum the rows of the collected batches
        let iter = offset.execute(0, task_ctx)?;
        let batches = collect(iter).await?;
        Ok(batches.iter().map(|batch| batch.num_rows()).sum())
    }

    #[tokio::test]
    async fn skip_none_fetch_none() -> Result<()> {
        let row_count = skip_and_fetch(0, None).await?;
        assert_eq!(row_count, 400);
        Ok(())
    }

    #[tokio::test]
    async fn skip_none_fetch_50() -> Result<()> {
        let row_count = skip_and_fetch(0, Some(50)).await?;
        assert_eq!(row_count, 50);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_none() -> Result<()> {
        // there are a total of 400 rows; we skip 3, so 397 remain
        let row_count = skip_and_fetch(3, None).await?;
        assert_eq!(row_count, 397);
        Ok(())
    }

    #[tokio::test]
    async fn skip_3_fetch_10_stats() -> Result<()> {
        // there are a total of 400 rows; we skip 3 and fetch 10
        let row_count = skip_and_fetch(3, Some(10)).await?;
        assert_eq!(row_count, 10);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_none() -> Result<()> {
        // skipping all 400 rows leaves nothing to return
        let row_count = skip_and_fetch(400, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_400_fetch_1() -> Result<()> {
        // there are only 400 rows, so fetching after skipping them all returns nothing
        let row_count = skip_and_fetch(400, Some(1)).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }

    #[tokio::test]
    async fn skip_401_fetch_none() -> Result<()> {
        // skipping past the end of the input also returns nothing
        let row_count = skip_and_fetch(401, None).await?;
        assert_eq!(row_count, 0);
        Ok(())
    }
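
    // Illustrative sketch, not from the original suite: when skip leaves fewer
    // rows than fetch asks for, the result is truncated to the remainder
    // (400 - 398 = 2 rows), matching the Exact(2) statistics checked below.
    #[tokio::test]
    async fn skip_398_fetch_10() -> Result<()> {
        let row_count = skip_and_fetch(398, Some(10)).await?;
        assert_eq!(row_count, 2);
        Ok(())
    }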

    #[tokio::test]
    async fn test_row_number_statistics_for_global_limit() -> Result<()> {
        let row_count = row_number_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(10));

        let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Exact(1));

        let row_count = row_number_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(400));

        let row_count =
            row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Exact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(10));

        let row_count =
            row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
        assert_eq!(row_count, Precision::Exact(0));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
        assert_eq!(row_count, Precision::Inexact(1));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, None).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        let row_count =
            row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(400));

        let row_count =
            row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
        assert_eq!(row_count, Precision::Inexact(2));

        Ok(())
    }

    #[tokio::test]
    async fn test_row_number_statistics_for_local_limit() -> Result<()> {
        let row_count = row_number_statistics_for_local_limit(4, 10).await?;
        assert_eq!(row_count, Precision::Exact(10));

        Ok(())
    }

    async fn row_number_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset =
            GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(csv)), skip, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    pub fn build_group_by(
        input_schema: &SchemaRef,
        columns: Vec<String>,
    ) -> PhysicalGroupBy {
        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
        for column in columns.iter() {
            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
        }
        PhysicalGroupBy::new_single(group_by_expr.clone())
    }

    async fn row_number_inexact_statistics_for_global_limit(
        skip: usize,
        fetch: Option<usize>,
    ) -> Result<Precision<usize>> {
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        // a grouped aggregate between the scan and the limit makes the input
        // row-count statistics inexact
        let agg = AggregateExec::try_new(
            AggregateMode::Final,
            build_group_by(&csv.schema(), vec!["i".to_string()]),
            vec![],
            vec![],
            Arc::clone(&csv),
            Arc::clone(&csv.schema()),
        )?;
        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);

        let offset = GlobalLimitExec::new(
            Arc::new(CoalescePartitionsExec::new(agg_exec)),
            skip,
            fetch,
        );

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    async fn row_number_statistics_for_local_limit(
        num_partitions: usize,
        fetch: usize,
    ) -> Result<Precision<usize>> {
        let csv = test::scan_partitioned(num_partitions);

        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let offset = LocalLimitExec::new(csv, fetch);

        Ok(offset.partition_statistics(None)?.num_rows)
    }

    /// Creates a batch with `sz` rows and no columns, used by the
    /// `limit_no_column` test above
    fn make_batch_no_column(sz: usize) -> RecordBatch {
        let schema = Arc::new(Schema::empty());

        let options = RecordBatchOptions::new().with_row_count(Option::from(sz));
        RecordBatch::try_new_with_options(schema, vec![], &options).unwrap()
    }
}
829}