1use std::any::Any;
21use std::sync::Arc;
22
23use crate::common::spawn_buffered;
24use crate::limit::LimitStream;
25use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
26use crate::projection::{ProjectionExec, make_with_child, update_ordering};
27use crate::sorts::streaming_merge::StreamingMergeBuilder;
28use crate::{
29 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
30 Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
31 check_if_same_properties,
32};
33
34use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
35use datafusion_execution::TaskContext;
36use datafusion_execution::memory_pool::MemoryConsumer;
37use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
38
39use crate::execution_plan::{EvaluationType, SchedulingType};
40use log::{debug, trace};
41
/// Merges the already-sorted partitions of its input into a single sorted
/// output partition (see `execute`, which feeds all input partitions into a
/// `StreamingMergeBuilder`).
#[derive(Debug, Clone)]
pub struct SortPreservingMergeExec {
    /// Input plan; each of its partitions must be sorted by `expr`
    /// (declared via `required_input_ordering`).
    input: Arc<dyn ExecutionPlan>,
    /// Sort expressions the merge preserves.
    expr: LexOrdering,
    /// Execution metrics.
    metrics: ExecutionPlanMetricsSet,
    /// Optional limit: stop producing rows after this many.
    fetch: Option<usize>,
    /// Cached plan properties (equivalences, partitioning, boundedness).
    cache: Arc<PlanProperties>,
    /// Whether the streaming merge uses the round-robin tie breaker
    /// (passed to `StreamingMergeBuilder::with_round_robin_tie_breaker`).
    enable_round_robin_repartition: bool,
}
103
104impl SortPreservingMergeExec {
105 pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
107 let cache = Self::compute_properties(&input, expr.clone());
108 Self {
109 input,
110 expr,
111 metrics: ExecutionPlanMetricsSet::new(),
112 fetch: None,
113 cache: Arc::new(cache),
114 enable_round_robin_repartition: true,
115 }
116 }
117
118 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
120 self.fetch = fetch;
121 self
122 }
123
124 pub fn with_round_robin_repartition(
134 mut self,
135 enable_round_robin_repartition: bool,
136 ) -> Self {
137 self.enable_round_robin_repartition = enable_round_robin_repartition;
138 self
139 }
140
141 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
143 &self.input
144 }
145
146 pub fn expr(&self) -> &LexOrdering {
148 &self.expr
149 }
150
151 pub fn fetch(&self) -> Option<usize> {
153 self.fetch
154 }
155
156 fn compute_properties(
159 input: &Arc<dyn ExecutionPlan>,
160 ordering: LexOrdering,
161 ) -> PlanProperties {
162 let input_partitions = input.output_partitioning().partition_count();
163 let (drive, scheduling) = if input_partitions > 1 {
164 (EvaluationType::Eager, SchedulingType::Cooperative)
165 } else {
166 (
167 input.properties().evaluation_type,
168 input.properties().scheduling_type,
169 )
170 };
171
172 let mut eq_properties = input.equivalence_properties().clone();
173 eq_properties.clear_per_partition_constants();
174 eq_properties.add_ordering(ordering);
175 PlanProperties::new(
176 eq_properties, Partitioning::UnknownPartitioning(1), input.pipeline_behavior(), input.boundedness(), )
181 .with_evaluation_type(drive)
182 .with_scheduling_type(scheduling)
183 }
184
185 fn with_new_children_and_same_properties(
186 &self,
187 mut children: Vec<Arc<dyn ExecutionPlan>>,
188 ) -> Self {
189 Self {
190 input: children.swap_remove(0),
191 metrics: ExecutionPlanMetricsSet::new(),
192 ..Self::clone(self)
193 }
194 }
195}
196
197impl DisplayAs for SortPreservingMergeExec {
198 fn fmt_as(
199 &self,
200 t: DisplayFormatType,
201 f: &mut std::fmt::Formatter,
202 ) -> std::fmt::Result {
203 match t {
204 DisplayFormatType::Default | DisplayFormatType::Verbose => {
205 write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
206 if let Some(fetch) = self.fetch {
207 write!(f, ", fetch={fetch}")?;
208 };
209
210 Ok(())
211 }
212 DisplayFormatType::TreeRender => {
213 if let Some(fetch) = self.fetch {
214 writeln!(f, "limit={fetch}")?;
215 };
216
217 for (i, e) in self.expr().iter().enumerate() {
218 e.fmt_sql(f)?;
219 if i != self.expr().len() - 1 {
220 write!(f, ", ")?;
221 }
222 }
223
224 Ok(())
225 }
226 }
227 }
228}
229
230impl ExecutionPlan for SortPreservingMergeExec {
231 fn name(&self) -> &'static str {
232 "SortPreservingMergeExec"
233 }
234
235 fn as_any(&self) -> &dyn Any {
237 self
238 }
239
240 fn properties(&self) -> &Arc<PlanProperties> {
241 &self.cache
242 }
243
244 fn fetch(&self) -> Option<usize> {
245 self.fetch
246 }
247
248 fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
250 Some(Arc::new(Self {
251 input: Arc::clone(&self.input),
252 expr: self.expr.clone(),
253 metrics: self.metrics.clone(),
254 fetch: limit,
255 cache: Arc::clone(&self.cache),
256 enable_round_robin_repartition: true,
257 }))
258 }
259
260 fn with_preserve_order(
261 &self,
262 preserve_order: bool,
263 ) -> Option<Arc<dyn ExecutionPlan>> {
264 self.input
265 .with_preserve_order(preserve_order)
266 .and_then(|new_input| {
267 Arc::new(self.clone())
268 .with_new_children(vec![new_input])
269 .ok()
270 })
271 }
272
273 fn required_input_distribution(&self) -> Vec<Distribution> {
274 vec![Distribution::UnspecifiedDistribution]
275 }
276
277 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
278 vec![false]
279 }
280
281 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
282 vec![Some(OrderingRequirements::from(self.expr.clone()))]
283 }
284
285 fn maintains_input_order(&self) -> Vec<bool> {
286 vec![true]
287 }
288
289 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
290 vec![&self.input]
291 }
292
293 fn with_new_children(
294 self: Arc<Self>,
295 mut children: Vec<Arc<dyn ExecutionPlan>>,
296 ) -> Result<Arc<dyn ExecutionPlan>> {
297 check_if_same_properties!(self, children);
298 Ok(Arc::new(
299 SortPreservingMergeExec::new(self.expr.clone(), children.swap_remove(0))
300 .with_fetch(self.fetch),
301 ))
302 }
303
304 fn execute(
305 &self,
306 partition: usize,
307 context: Arc<TaskContext>,
308 ) -> Result<SendableRecordBatchStream> {
309 trace!("Start SortPreservingMergeExec::execute for partition: {partition}");
310 assert_eq_or_internal_err!(
311 partition,
312 0,
313 "SortPreservingMergeExec invalid partition {partition}"
314 );
315
316 let input_partitions = self.input.output_partitioning().partition_count();
317 trace!(
318 "Number of input partitions of SortPreservingMergeExec::execute: {input_partitions}"
319 );
320 let schema = self.schema();
321
322 let reservation =
323 MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
324 .register(&context.runtime_env().memory_pool);
325
326 match input_partitions {
327 0 => internal_err!(
328 "SortPreservingMergeExec requires at least one input partition"
329 ),
330 1 => match self.fetch {
331 Some(fetch) => {
332 let stream = self.input.execute(0, context)?;
333 debug!(
334 "Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"
335 );
336 Ok(Box::pin(LimitStream::new(
337 stream,
338 0,
339 Some(fetch),
340 BaselineMetrics::new(&self.metrics, partition),
341 )))
342 }
343 None => {
344 let stream = self.input.execute(0, context);
345 debug!(
346 "Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"
347 );
348 stream
349 }
350 },
351 _ => {
352 let receivers = (0..input_partitions)
353 .map(|partition| {
354 let stream =
355 self.input.execute(partition, Arc::clone(&context))?;
356 Ok(spawn_buffered(stream, 1))
357 })
358 .collect::<Result<_>>()?;
359
360 debug!(
361 "Done setting up sender-receiver for SortPreservingMergeExec::execute"
362 );
363
364 let result = StreamingMergeBuilder::new()
365 .with_streams(receivers)
366 .with_schema(schema)
367 .with_expressions(&self.expr)
368 .with_metrics(BaselineMetrics::new(&self.metrics, partition))
369 .with_batch_size(context.session_config().batch_size())
370 .with_fetch(self.fetch)
371 .with_reservation(reservation)
372 .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
373 .build()?;
374
375 debug!(
376 "Got stream result from SortPreservingMergeStream::new_from_receivers"
377 );
378
379 Ok(result)
380 }
381 }
382 }
383
384 fn metrics(&self) -> Option<MetricsSet> {
385 Some(self.metrics.clone_inner())
386 }
387
388 fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
389 self.input.partition_statistics(None)
390 }
391
392 fn supports_limit_pushdown(&self) -> bool {
393 true
394 }
395
396 fn try_swapping_with_projection(
400 &self,
401 projection: &ProjectionExec,
402 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
403 if projection.expr().len() >= projection.input().schema().fields().len() {
405 return Ok(None);
406 }
407
408 let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())?
409 else {
410 return Ok(None);
411 };
412
413 Ok(Some(Arc::new(
414 SortPreservingMergeExec::new(
415 updated_exprs,
416 make_with_child(projection, self.input())?,
417 )
418 .with_fetch(self.fetch()),
419 )))
420 }
421}
422
423#[cfg(test)]
424mod tests {
425 use std::collections::HashSet;
426 use std::fmt::Formatter;
427 use std::pin::Pin;
428 use std::sync::Mutex;
429 use std::task::{Context, Poll, Waker, ready};
430 use std::time::Duration;
431
432 use super::*;
433 use crate::coalesce_partitions::CoalescePartitionsExec;
434 use crate::execution_plan::{Boundedness, EmissionType};
435 use crate::expressions::col;
436 use crate::metrics::{MetricValue, Timestamp};
437 use crate::repartition::RepartitionExec;
438 use crate::sorts::sort::SortExec;
439 use crate::stream::RecordBatchReceiverStream;
440 use crate::test::TestMemoryExec;
441 use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
442 use crate::test::{self, assert_is_pending, make_partition};
443 use crate::{collect, common};
444
445 use arrow::array::{
446 ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
447 TimestampNanosecondArray,
448 };
449 use arrow::compute::SortOptions;
450 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
451 use datafusion_common::test_util::batches_to_string;
452 use datafusion_common::{assert_batches_eq, exec_err};
453 use datafusion_common_runtime::SpawnedTask;
454 use datafusion_execution::RecordBatchStream;
455 use datafusion_execution::config::SessionConfig;
456 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
457 use datafusion_physical_expr::EquivalenceProperties;
458 use datafusion_physical_expr::expressions::Column;
459 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
460 use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
461
462 use futures::{FutureExt, Stream, StreamExt};
463 use insta::assert_snapshot;
464 use tokio::time::timeout;
465
466 fn generate_task_ctx_for_round_robin_tie_breaker(
469 target_batch_size: usize,
470 ) -> Result<Arc<TaskContext>> {
471 let runtime = RuntimeEnvBuilder::new()
472 .with_memory_limit(20_000_000, 1.0)
473 .build_arc()?;
474 let mut config = SessionConfig::new();
475 config.options_mut().execution.batch_size = target_batch_size;
476 let task_ctx = TaskContext::default()
477 .with_runtime(runtime)
478 .with_session_config(config);
479 Ok(Arc::new(task_ctx))
480 }
    /// Build a `SortPreservingMergeExec` over two round-robin repartitioned
    /// streams whose rows are all identical, so the merge must break ties on
    /// every comparison. The data volume (12500-row batch repeated 1024
    /// times) is sized to pressure the memory pool of the tie-breaker tests.
    fn generate_spm_for_round_robin_tie_breaker(
        enable_round_robin_repartition: bool,
    ) -> Result<Arc<SortPreservingMergeExec>> {
        let row_size = 12500;
        // Every row is the same, so sort columns `b` and `c` always compare
        // equal across the merged streams.
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
        let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
        let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?;
        let schema = rb.schema();

        let rbs = std::iter::repeat_n(rb, 1024).collect::<Vec<_>>();
        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema)?,
                options: Default::default(),
            },
            PhysicalSortExpr {
                expr: col("c", &schema)?,
                options: Default::default(),
            },
        ]
        .into();

        // Two round-robin output partitions feed the merge.
        let repartition_exec = RepartitionExec::try_new(
            TestMemoryExec::try_new_exec(&[rbs], schema, None)?,
            Partitioning::RoundRobinBatch(2),
        )?;
        let spm = SortPreservingMergeExec::new(sort, Arc::new(repartition_exec))
            .with_round_robin_repartition(enable_round_robin_repartition);
        Ok(Arc::new(spm))
    }
514
515 #[tokio::test(flavor = "multi_thread")]
521 async fn test_round_robin_tie_breaker_success() -> Result<()> {
522 let target_batch_size = 12500;
523 let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(target_batch_size)?;
524 let spm = generate_spm_for_round_robin_tie_breaker(true)?;
525 let _collected = collect(spm, task_ctx).await?;
526 Ok(())
527 }
528
529 #[tokio::test(flavor = "multi_thread")]
535 async fn test_round_robin_tie_breaker_fail() -> Result<()> {
536 let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(8192)?;
537 let spm = generate_spm_for_round_robin_tie_breaker(false)?;
538 let _err = collect(spm, task_ctx).await.unwrap_err();
539 Ok(())
540 }
541
    /// Merge two partitions whose sort keys strictly interleave; the output
    /// must alternate rows from each input.
    #[tokio::test]
    async fn test_merge_interleave() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("c"),
            Some("e"),
            Some("g"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("b"),
            Some("d"),
            Some("f"),
            Some("h"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a | b | c |",
                "+----+---+-------------------------------+",
                "| 1 | a | 1970-01-01T00:00:00.000000008 |",
                "| 10 | b | 1970-01-01T00:00:00.000000004 |",
                "| 2 | c | 1970-01-01T00:00:00.000000007 |",
                "| 20 | d | 1970-01-01T00:00:00.000000006 |",
                "| 7 | e | 1970-01-01T00:00:00.000000006 |",
                "| 70 | f | 1970-01-01T00:00:00.000000002 |",
                "| 9 | g | 1970-01-01T00:00:00.000000005 |",
                "| 90 | h | 1970-01-01T00:00:00.000000002 |",
                // The two "j" rows tie on `b` and are ordered by `c`.
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "| 3 | j | 1970-01-01T00:00:00.000000008 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
589
    /// Merge two partitions whose key ranges partially overlap; ties on `b`
    /// are broken by the secondary sort column `c`.
    #[tokio::test]
    async fn test_merge_some_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("c"),
            Some("d"),
            Some("e"),
            Some("f"),
            Some("g"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+-----+---+-------------------------------+",
                "| a | b | c |",
                "+-----+---+-------------------------------+",
                "| 1 | a | 1970-01-01T00:00:00.000000008 |",
                "| 2 | b | 1970-01-01T00:00:00.000000007 |",
                "| 70 | c | 1970-01-01T00:00:00.000000004 |",
                "| 7 | c | 1970-01-01T00:00:00.000000006 |",
                "| 9 | d | 1970-01-01T00:00:00.000000005 |",
                "| 90 | d | 1970-01-01T00:00:00.000000006 |",
                "| 30 | e | 1970-01-01T00:00:00.000000002 |",
                "| 3 | e | 1970-01-01T00:00:00.000000008 |",
                "| 100 | f | 1970-01-01T00:00:00.000000002 |",
                "| 110 | g | 1970-01-01T00:00:00.000000006 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
637
    /// Merge two partitions with disjoint key ranges; the output is simply
    /// the first partition followed by the second.
    #[tokio::test]
    async fn test_merge_no_overlap() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2]],
            &[
                "+----+---+-------------------------------+",
                "| a | b | c |",
                "+----+---+-------------------------------+",
                "| 1 | a | 1970-01-01T00:00:00.000000008 |",
                "| 2 | b | 1970-01-01T00:00:00.000000007 |",
                "| 7 | c | 1970-01-01T00:00:00.000000006 |",
                "| 9 | d | 1970-01-01T00:00:00.000000005 |",
                "| 3 | e | 1970-01-01T00:00:00.000000008 |",
                "| 10 | f | 1970-01-01T00:00:00.000000004 |",
                "| 20 | g | 1970-01-01T00:00:00.000000006 |",
                "| 70 | h | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000002 |",
                "| 30 | j | 1970-01-01T00:00:00.000000006 |",
                "+----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
685
    /// Merge three partitions with overlapping keys; rows tying on `b` are
    /// ordered by the secondary column `c`.
    #[tokio::test]
    async fn test_merge_three_partitions() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("a"),
            Some("b"),
            Some("c"),
            Some("d"),
            Some("f"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("e"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef =
            Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            Some("f"),
            Some("g"),
            Some("h"),
            Some("i"),
            Some("j"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
        let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        _test_merge(
            &[vec![b1], vec![b2], vec![b3]],
            &[
                "+-----+---+-------------------------------+",
                "| a | b | c |",
                "+-----+---+-------------------------------+",
                "| 1 | a | 1970-01-01T00:00:00.000000008 |",
                "| 2 | b | 1970-01-01T00:00:00.000000007 |",
                "| 7 | c | 1970-01-01T00:00:00.000000006 |",
                "| 9 | d | 1970-01-01T00:00:00.000000005 |",
                "| 10 | e | 1970-01-01T00:00:00.000000040 |",
                "| 100 | f | 1970-01-01T00:00:00.000000004 |",
                "| 3 | f | 1970-01-01T00:00:00.000000008 |",
                "| 200 | g | 1970-01-01T00:00:00.000000006 |",
                "| 20 | g | 1970-01-01T00:00:00.000000060 |",
                "| 700 | h | 1970-01-01T00:00:00.000000002 |",
                "| 70 | h | 1970-01-01T00:00:00.000000020 |",
                "| 900 | i | 1970-01-01T00:00:00.000000002 |",
                "| 90 | i | 1970-01-01T00:00:00.000000020 |",
                "| 300 | j | 1970-01-01T00:00:00.000000006 |",
                "| 30 | j | 1970-01-01T00:00:00.000000060 |",
                "+-----+---+-------------------------------+",
            ],
            task_ctx,
        )
        .await;
    }
750
751 async fn _test_merge(
752 partitions: &[Vec<RecordBatch>],
753 exp: &[&str],
754 context: Arc<TaskContext>,
755 ) {
756 let schema = partitions[0][0].schema();
757 let sort = [
758 PhysicalSortExpr {
759 expr: col("b", &schema).unwrap(),
760 options: Default::default(),
761 },
762 PhysicalSortExpr {
763 expr: col("c", &schema).unwrap(),
764 options: Default::default(),
765 },
766 ]
767 .into();
768 let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
769 let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
770
771 let collected = collect(merge, context).await.unwrap();
772 assert_batches_eq!(exp, collected.as_slice());
773 }
774
775 async fn sorted_merge(
776 input: Arc<dyn ExecutionPlan>,
777 sort: LexOrdering,
778 context: Arc<TaskContext>,
779 ) -> RecordBatch {
780 let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
781 let mut result = collect(merge, context).await.unwrap();
782 assert_eq!(result.len(), 1);
783 result.remove(0)
784 }
785
786 async fn partition_sort(
787 input: Arc<dyn ExecutionPlan>,
788 sort: LexOrdering,
789 context: Arc<TaskContext>,
790 ) -> RecordBatch {
791 let sort_exec =
792 Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
793 sorted_merge(sort_exec, sort, context).await
794 }
795
796 async fn basic_sort(
797 src: Arc<dyn ExecutionPlan>,
798 sort: LexOrdering,
799 context: Arc<TaskContext>,
800 ) -> RecordBatch {
801 let merge = Arc::new(CoalescePartitionsExec::new(src));
802 let sort_exec = Arc::new(SortExec::new(sort, merge));
803 let mut result = collect(sort_exec, context).await.unwrap();
804 assert_eq!(result.len(), 1);
805 result.remove(0)
806 }
807
808 #[tokio::test]
809 async fn test_partition_sort() -> Result<()> {
810 let task_ctx = Arc::new(TaskContext::default());
811 let partitions = 4;
812 let csv = test::scan_partitioned(partitions);
813 let schema = csv.schema();
814
815 let sort: LexOrdering = [PhysicalSortExpr {
816 expr: col("i", &schema)?,
817 options: SortOptions {
818 descending: true,
819 nulls_first: true,
820 },
821 }]
822 .into();
823
824 let basic =
825 basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
826 let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;
827
828 let basic = arrow::util::pretty::pretty_format_batches(&[basic])
829 .unwrap()
830 .to_string();
831 let partition = arrow::util::pretty::pretty_format_batches(&[partition])
832 .unwrap()
833 .to_string();
834
835 assert_eq!(
836 basic, partition,
837 "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
838 );
839
840 Ok(())
841 }
842
843 fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
845 let batches = sorted.num_rows().div_ceil(batch_size);
846
847 (0..batches)
849 .map(|batch_idx| {
850 let columns = (0..sorted.num_columns())
851 .map(|column_idx| {
852 let length =
853 batch_size.min(sorted.num_rows() - batch_idx * batch_size);
854
855 sorted
856 .column(column_idx)
857 .slice(batch_idx * batch_size, length)
858 })
859 .collect();
860
861 RecordBatch::try_new(sorted.schema(), columns).unwrap()
862 })
863 .collect()
864 }
865
866 async fn sorted_partitioned_input(
867 sort: LexOrdering,
868 sizes: &[usize],
869 context: Arc<TaskContext>,
870 ) -> Result<Arc<dyn ExecutionPlan>> {
871 let partitions = 4;
872 let csv = test::scan_partitioned(partitions);
873
874 let sorted = basic_sort(csv, sort, context).await;
875 let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();
876
877 TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _)
878 }
879
    /// Merging pre-sorted partitions with mismatched batch sizes must equal
    /// a global sort of the same data.
    #[tokio::test]
    async fn test_partition_sort_streaming_input() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        // Three partitions with deliberately different batch sizes.
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
                .await?;
        let basic =
            basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
        let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;

        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(partition.num_rows(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition =
            arrow::util::pretty::pretty_format_batches(&[partition])?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }
908
    /// Merging with a small output batch size (23) must re-chunk the output
    /// (53 batches for 1200 rows) while preserving content and order.
    #[tokio::test]
    async fn test_partition_sort_streaming_input_output() -> Result<()> {
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema)?,
            options: Default::default(),
        }]
        .into();

        // Test streaming with default batch size
        let task_ctx = Arc::new(TaskContext::default());
        let input =
            sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
                .await?;
        let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;

        // batch size of 23 — deliberately not a divisor of 1200.
        let task_ctx = TaskContext::default()
            .with_session_config(SessionConfig::new().with_batch_size(23));
        let task_ctx = Arc::new(task_ctx);

        let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
        let merged = collect(merge, task_ctx).await?;

        assert_eq!(merged.len(), 53);
        assert_eq!(basic.num_rows(), 1200);
        assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string();

        assert_eq!(basic, partition);

        Ok(())
    }
944
    /// Merge inputs containing nulls in both sort columns: `b` sorts with
    /// nulls first, `c` with nulls last.
    #[tokio::test]
    async fn test_nulls() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("a"),
            Some("b"),
            Some("d"),
            Some("e"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(6),
            None,
            Some(4),
        ]));
        let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();

        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
        let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
            None,
            Some("b"),
            Some("g"),
            Some("h"),
            Some("i"),
        ]));
        let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
            Some(8),
            None,
            Some(5),
            None,
            Some(4),
        ]));
        let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
        let schema = b1.schema();

        let sort = [
            PhysicalSortExpr {
                expr: col("b", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: true,
                },
            },
            PhysicalSortExpr {
                expr: col("c", &schema).unwrap(),
                options: SortOptions {
                    descending: false,
                    nulls_first: false,
                },
            },
        ]
        .into();
        let exec =
            TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        // Null `b` rows come first; within ties on `b`, null `c` sorts last.
        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +---+---+-------------------------------+
        | a | b | c |
        +---+---+-------------------------------+
        | 1 | | 1970-01-01T00:00:00.000000008 |
        | 1 | | 1970-01-01T00:00:00.000000008 |
        | 2 | a | |
        | 7 | b | 1970-01-01T00:00:00.000000006 |
        | 2 | b | |
        | 9 | d | |
        | 3 | e | 1970-01-01T00:00:00.000000004 |
        | 3 | g | 1970-01-01T00:00:00.000000005 |
        | 4 | h | |
        | 5 | i | 1970-01-01T00:00:00.000000004 |
        +---+---+-------------------------------+
        ");
    }
1024
1025 #[tokio::test]
1026 async fn test_sort_merge_single_partition_with_fetch() {
1027 let task_ctx = Arc::new(TaskContext::default());
1028 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
1029 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1030 let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1031 let schema = batch.schema();
1032
1033 let sort = [PhysicalSortExpr {
1034 expr: col("b", &schema).unwrap(),
1035 options: SortOptions {
1036 descending: false,
1037 nulls_first: true,
1038 },
1039 }]
1040 .into();
1041 let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
1042 let merge =
1043 Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));
1044
1045 let collected = collect(merge, task_ctx).await.unwrap();
1046 assert_eq!(collected.len(), 1);
1047
1048 assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1049 +---+---+
1050 | a | b |
1051 +---+---+
1052 | 1 | a |
1053 | 2 | b |
1054 +---+---+
1055 ");
1056 }
1057
    /// A single-partition merge without `fetch` passes the input through
    /// unchanged (no limit applied).
    #[tokio::test]
    async fn test_sort_merge_single_partition_without_fetch() {
        let task_ctx = Arc::new(TaskContext::default());
        let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
        let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
        let schema = batch.schema();

        let sort = [PhysicalSortExpr {
            expr: col("b", &schema).unwrap(),
            options: SortOptions {
                descending: false,
                nulls_first: true,
            },
        }]
        .into();
        let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
        let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));

        let collected = collect(merge, task_ctx).await.unwrap();
        assert_eq!(collected.len(), 1);

        assert_snapshot!(batches_to_string(collected.as_slice()), @r"
        +---+---+
        | a | b |
        +---+---+
        | 1 | a |
        | 2 | b |
        | 7 | c |
        | 9 | d |
        | 3 | e |
        +---+---+
        ");
    }
1092
    /// Drive the low-level `StreamingMergeBuilder` directly over async
    /// streams that yield batches with artificial delays, and check the
    /// merged result matches a plain global sort of the same input.
    #[tokio::test]
    async fn test_async() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema = make_partition(11).schema();
        let sort: LexOrdering = [PhysicalSortExpr {
            expr: col("i", &schema).unwrap(),
            options: SortOptions::default(),
        }]
        .into();

        let batches =
            sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
                .await?;

        let partition_count = batches.output_partitioning().partition_count();
        let mut streams = Vec::with_capacity(partition_count);

        // Wrap each partition in a channel-backed stream that sleeps after
        // every batch, exercising the merge's handling of pending inputs.
        for partition in 0..partition_count {
            let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);

            let sender = builder.tx();

            let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
            builder.spawn(async move {
                while let Some(batch) = stream.next().await {
                    sender.send(batch).await.unwrap();
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }

                Ok(())
            });

            streams.push(builder.build());
        }

        let metrics = ExecutionPlanMetricsSet::new();
        let reservation =
            MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);

        let fetch = None;
        let merge_stream = StreamingMergeBuilder::new()
            .with_streams(streams)
            .with_schema(batches.schema())
            .with_expressions(&sort)
            .with_metrics(BaselineMetrics::new(&metrics, 0))
            .with_batch_size(task_ctx.session_config().batch_size())
            .with_fetch(fetch)
            .with_reservation(reservation)
            .build()?;

        let mut merged = common::collect(merge_stream).await.unwrap();

        assert_eq!(merged.len(), 1);
        let merged = merged.remove(0);
        // Reference: coalesce + global sort of the same input.
        let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;

        let basic = arrow::util::pretty::pretty_format_batches(&[basic])
            .unwrap()
            .to_string();
        let partition = arrow::util::pretty::pretty_format_batches(&[merged])
            .unwrap()
            .to_string();

        assert_eq!(
            basic, partition,
            "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
        );

        Ok(())
    }
1164
1165 #[tokio::test]
1166 async fn test_merge_metrics() {
1167 let task_ctx = Arc::new(TaskContext::default());
1168 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1169 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
1170 let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1171
1172 let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
1173 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
1174 let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1175
1176 let schema = b1.schema();
1177 let sort = [PhysicalSortExpr {
1178 expr: col("b", &schema).unwrap(),
1179 options: Default::default(),
1180 }]
1181 .into();
1182 let exec =
1183 TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
1184 let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1185
1186 let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
1187 .await
1188 .unwrap();
1189 assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1190 +----+---+
1191 | a | b |
1192 +----+---+
1193 | 1 | a |
1194 | 10 | b |
1195 | 2 | c |
1196 | 20 | d |
1197 +----+---+
1198 ");
1199
1200 let metrics = merge.metrics().unwrap();
1202
1203 assert_eq!(metrics.output_rows().unwrap(), 4);
1204 assert!(metrics.elapsed_compute().unwrap() > 0);
1205
1206 let mut saw_start = false;
1207 let mut saw_end = false;
1208 metrics.iter().for_each(|m| match m.value() {
1209 MetricValue::StartTimestamp(ts) => {
1210 saw_start = true;
1211 assert!(nanos_from_timestamp(ts) > 0);
1212 }
1213 MetricValue::EndTimestamp(ts) => {
1214 saw_end = true;
1215 assert!(nanos_from_timestamp(ts) > 0);
1216 }
1217 _ => {}
1218 });
1219
1220 assert!(saw_start);
1221 assert!(saw_end);
1222 }
1223
1224 fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
1225 ts.value().unwrap().timestamp_nanos_opt().unwrap()
1226 }
1227
1228 #[tokio::test]
1229 async fn test_drop_cancel() -> Result<()> {
1230 let task_ctx = Arc::new(TaskContext::default());
1231 let schema =
1232 Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
1233
1234 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
1235 let refs = blocking_exec.refs();
1236 let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
1237 [PhysicalSortExpr {
1238 expr: col("a", &schema)?,
1239 options: SortOptions::default(),
1240 }]
1241 .into(),
1242 blocking_exec,
1243 ));
1244
1245 let fut = collect(sort_preserving_merge_exec, task_ctx);
1246 let mut fut = fut.boxed();
1247
1248 assert_is_pending(&mut fut);
1249 drop(fut);
1250 assert_strong_count_converges_to_zero(refs).await;
1251
1252 Ok(())
1253 }
1254
1255 #[tokio::test]
1256 async fn test_stable_sort() {
1257 let task_ctx = Arc::new(TaskContext::default());
1258
1259 let partitions: Vec<Vec<RecordBatch>> = (0..10)
1267 .map(|batch_number| {
1268 let batch_number: Int32Array =
1269 vec![Some(batch_number), Some(batch_number)]
1270 .into_iter()
1271 .collect();
1272 let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();
1273
1274 let batch = RecordBatch::try_from_iter(vec![
1275 ("batch_number", Arc::new(batch_number) as ArrayRef),
1276 ("value", Arc::new(value) as ArrayRef),
1277 ])
1278 .unwrap();
1279
1280 vec![batch]
1281 })
1282 .collect();
1283
1284 let schema = partitions[0][0].schema();
1285
1286 let sort = [PhysicalSortExpr {
1287 expr: col("value", &schema).unwrap(),
1288 options: SortOptions {
1289 descending: false,
1290 nulls_first: true,
1291 },
1292 }]
1293 .into();
1294
1295 let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
1296 let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1297
1298 let collected = collect(merge, task_ctx).await.unwrap();
1299 assert_eq!(collected.len(), 1);
1300
1301 assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1305 +--------------+-------+
1306 | batch_number | value |
1307 +--------------+-------+
1308 | 0 | A |
1309 | 1 | A |
1310 | 2 | A |
1311 | 3 | A |
1312 | 4 | A |
1313 | 5 | A |
1314 | 6 | A |
1315 | 7 | A |
1316 | 8 | A |
1317 | 9 | A |
1318 | 0 | B |
1319 | 1 | B |
1320 | 2 | B |
1321 | 3 | B |
1322 | 4 | B |
1323 | 5 | B |
1324 | 6 | B |
1325 | 7 | B |
1326 | 8 | B |
1327 | 9 | B |
1328 +--------------+-------+
1329 ");
1330 }
1331
    /// Mutable state behind [`Congestion`]'s mutex.
    #[derive(Debug)]
    struct CongestionState {
        // Wakers of streams parked until every partition has been polled once.
        wakers: Vec<Waker>,
        // Partitions the consumer has not yet polled.
        unpolled_partitions: HashSet<usize>,
    }
1337
    /// Synchronization point shared by all partition streams of a
    /// [`CongestedExec`]: no stream may finish until the consumer has polled
    /// every partition at least once.
    #[derive(Debug)]
    struct Congestion {
        congestion_state: Mutex<CongestionState>,
    }
1342
1343 impl Congestion {
1344 fn new(partition_count: usize) -> Self {
1345 Congestion {
1346 congestion_state: Mutex::new(CongestionState {
1347 wakers: vec![],
1348 unpolled_partitions: (0usize..partition_count).collect(),
1349 }),
1350 }
1351 }
1352
1353 fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
1354 let mut state = self.congestion_state.lock().unwrap();
1355
1356 state.unpolled_partitions.remove(&partition);
1357
1358 if state.unpolled_partitions.is_empty() {
1359 state.wakers.iter().for_each(|w| w.wake_by_ref());
1360 state.wakers.clear();
1361 Poll::Ready(())
1362 } else {
1363 state.wakers.push(cx.waker().clone());
1364 Poll::Pending
1365 }
1366 }
1367 }
1368
    /// Mock `ExecutionPlan` whose partition streams cooperate through a shared
    /// [`Congestion`], used to detect deadlocks / re-polling bugs in the merge.
    #[derive(Debug, Clone)]
    struct CongestedExec {
        schema: Schema,
        // Pre-computed plan properties (see `compute_properties`).
        cache: Arc<PlanProperties>,
        // Shared "everyone must be polled once" barrier.
        congestion: Arc<Congestion>,
    }
1377
1378 impl CongestedExec {
1379 fn compute_properties(schema: SchemaRef) -> PlanProperties {
1380 let columns = schema
1381 .fields
1382 .iter()
1383 .enumerate()
1384 .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
1385 .collect::<Vec<_>>();
1386 let mut eq_properties = EquivalenceProperties::new(schema);
1387 eq_properties.add_ordering(
1388 columns
1389 .iter()
1390 .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))),
1391 );
1392 PlanProperties::new(
1393 eq_properties,
1394 Partitioning::Hash(columns, 3),
1395 EmissionType::Incremental,
1396 Boundedness::Unbounded {
1397 requires_infinite_memory: false,
1398 },
1399 )
1400 }
1401 }
1402
    // Minimal `ExecutionPlan` implementation: a leaf node whose only job is to
    // hand out `CongestedStream`s that share this plan's `Congestion` barrier.
    impl ExecutionPlan for CongestedExec {
        fn name(&self) -> &'static str {
            Self::static_name()
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn properties(&self) -> &Arc<PlanProperties> {
            &self.cache
        }
        // Leaf node: no children.
        fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
            vec![]
        }
        // No children, so "replacing" them returns self unchanged.
        fn with_new_children(
            self: Arc<Self>,
            _: Vec<Arc<dyn ExecutionPlan>>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            Ok(self)
        }
        // Each partition gets a fresh stream sharing the same congestion state.
        fn execute(
            &self,
            partition: usize,
            _context: Arc<TaskContext>,
        ) -> Result<SendableRecordBatchStream> {
            Ok(Box::pin(CongestedStream {
                schema: Arc::new(self.schema.clone()),
                none_polled_once: false,
                congestion: Arc::clone(&self.congestion),
                partition,
            }))
        }
    }
1435
1436 impl DisplayAs for CongestedExec {
1437 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
1438 match t {
1439 DisplayFormatType::Default | DisplayFormatType::Verbose => {
1440 write!(f, "CongestedExec",).unwrap()
1441 }
1442 DisplayFormatType::TreeRender => {
1443 write!(f, "").unwrap()
1445 }
1446 }
1447 Ok(())
1448 }
1449 }
1450
    /// Stream for one partition of [`CongestedExec`]; the congestion protocol
    /// lives in its `Stream::poll_next` implementation.
    #[derive(Debug)]
    pub struct CongestedStream {
        schema: SchemaRef,
        // Set once partition 0 has returned `None`; a second poll then panics.
        none_polled_once: bool,
        // Barrier shared with every other partition's stream.
        congestion: Arc<Congestion>,
        partition: usize,
    }
1460
    // Congestion protocol: partition 0 ends immediately (and panics if polled
    // again after returning `None`); every other partition stays `Pending`
    // until all partitions have been polled at least once. The enclosing test
    // therefore fails if the merge re-polls an exhausted stream or deadlocks
    // by never polling some stream.
    impl Stream for CongestedStream {
        type Item = Result<RecordBatch>;
        fn poll_next(
            mut self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            match self.partition {
                0 => {
                    // Register partition 0 as polled; the Pending/Ready result
                    // is deliberately ignored so this stream can finish now.
                    let _ = self.congestion.check_congested(self.partition, cx);
                    if self.none_polled_once {
                        panic!("Exhausted stream is polled more than once")
                    } else {
                        self.none_polled_once = true;
                        Poll::Ready(None)
                    }
                }
                _ => {
                    // Park until every partition (including 0) has been polled.
                    ready!(self.congestion.check_congested(self.partition, cx));
                    Poll::Ready(None)
                }
            }
        }
    }
1484
1485 impl RecordBatchStream for CongestedStream {
1486 fn schema(&self) -> SchemaRef {
1487 Arc::clone(&self.schema)
1488 }
1489 }
1490
1491 #[tokio::test]
1492 async fn test_spm_congestion() -> Result<()> {
1493 let task_ctx = Arc::new(TaskContext::default());
1494 let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
1495 let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
1496 let &partition_count = match properties.output_partitioning() {
1497 Partitioning::RoundRobinBatch(partitions) => partitions,
1498 Partitioning::Hash(_, partitions) => partitions,
1499 Partitioning::UnknownPartitioning(partitions) => partitions,
1500 };
1501 let source = CongestedExec {
1502 schema: schema.clone(),
1503 cache: Arc::new(properties),
1504 congestion: Arc::new(Congestion::new(partition_count)),
1505 };
1506 let spm = SortPreservingMergeExec::new(
1507 [PhysicalSortExpr::new_default(Arc::new(Column::new(
1508 "c1", 0,
1509 )))]
1510 .into(),
1511 Arc::new(source),
1512 );
1513 let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));
1514
1515 let result = timeout(Duration::from_secs(3), spm_task.join()).await;
1516 match result {
1517 Ok(Ok(Ok(_batches))) => Ok(()),
1518 Ok(Ok(Err(e))) => Err(e),
1519 Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"),
1520 Err(_) => exec_err!("SortPreservingMerge caused a deadlock"),
1521 }
1522 }
1523}