1use std::sync::Arc;
21
22use crate::common::spawn_buffered;
23use crate::limit::LimitStream;
24use crate::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
25use crate::projection::{ProjectionExec, make_with_child, update_ordering};
26use crate::sorts::streaming_merge::StreamingMergeBuilder;
27use crate::{
28 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
29 Partitioning, PlanProperties, SendableRecordBatchStream, Statistics,
30 check_if_same_properties,
31};
32
33use datafusion_common::{Result, assert_eq_or_internal_err, internal_err};
34use datafusion_execution::TaskContext;
35use datafusion_execution::memory_pool::MemoryConsumer;
36use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
37
38use crate::execution_plan::{EvaluationType, SchedulingType};
39use log::{debug, trace};
40
/// Execution plan node that merges all partitions of `input` into a single
/// output partition while maintaining the ordering described by `expr`.
///
/// The input is required to satisfy `expr` already (see
/// `required_input_ordering`); the plan properties always advertise exactly
/// one output partition.
#[derive(Debug, Clone)]
pub struct SortPreservingMergeExec {
    /// Child plan whose partitions are merged.
    input: Arc<dyn ExecutionPlan>,
    /// Sort expressions describing the ordering kept by the merge.
    expr: LexOrdering,
    /// Metrics recorded while executing this node.
    metrics: ExecutionPlanMetricsSet,
    /// Optional fetch, handed to `LimitStream` (single input partition) or to
    /// `StreamingMergeBuilder::with_fetch` (several input partitions).
    fetch: Option<usize>,
    /// Cached plan properties computed from the input and `expr`.
    cache: Arc<PlanProperties>,
    /// Passed to `StreamingMergeBuilder::with_round_robin_tie_breaker` in
    /// `execute`; `new` initialises it to `true`.
    enable_round_robin_repartition: bool,
}
102
103impl SortPreservingMergeExec {
104 pub fn new(expr: LexOrdering, input: Arc<dyn ExecutionPlan>) -> Self {
106 let cache = Self::compute_properties(&input, expr.clone());
107 Self {
108 input,
109 expr,
110 metrics: ExecutionPlanMetricsSet::new(),
111 fetch: None,
112 cache: Arc::new(cache),
113 enable_round_robin_repartition: true,
114 }
115 }
116
117 pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
119 self.fetch = fetch;
120 self
121 }
122
123 pub fn with_round_robin_repartition(
133 mut self,
134 enable_round_robin_repartition: bool,
135 ) -> Self {
136 self.enable_round_robin_repartition = enable_round_robin_repartition;
137 self
138 }
139
140 pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
142 &self.input
143 }
144
145 pub fn expr(&self) -> &LexOrdering {
147 &self.expr
148 }
149
150 pub fn fetch(&self) -> Option<usize> {
152 self.fetch
153 }
154
155 fn compute_properties(
158 input: &Arc<dyn ExecutionPlan>,
159 ordering: LexOrdering,
160 ) -> PlanProperties {
161 let input_partitions = input.output_partitioning().partition_count();
162 let (drive, scheduling) = if input_partitions > 1 {
163 (EvaluationType::Eager, SchedulingType::Cooperative)
164 } else {
165 (
166 input.properties().evaluation_type,
167 input.properties().scheduling_type,
168 )
169 };
170
171 let mut eq_properties = input.equivalence_properties().clone();
172 eq_properties.clear_per_partition_constants();
173 eq_properties.add_ordering(ordering);
174 PlanProperties::new(
175 eq_properties, Partitioning::UnknownPartitioning(1), input.pipeline_behavior(), input.boundedness(), )
180 .with_evaluation_type(drive)
181 .with_scheduling_type(scheduling)
182 }
183
184 fn with_new_children_and_same_properties(
185 &self,
186 mut children: Vec<Arc<dyn ExecutionPlan>>,
187 ) -> Self {
188 Self {
189 input: children.swap_remove(0),
190 metrics: ExecutionPlanMetricsSet::new(),
191 ..Self::clone(self)
192 }
193 }
194}
195
196impl DisplayAs for SortPreservingMergeExec {
197 fn fmt_as(
198 &self,
199 t: DisplayFormatType,
200 f: &mut std::fmt::Formatter,
201 ) -> std::fmt::Result {
202 match t {
203 DisplayFormatType::Default | DisplayFormatType::Verbose => {
204 write!(f, "SortPreservingMergeExec: [{}]", self.expr)?;
205 if let Some(fetch) = self.fetch {
206 write!(f, ", fetch={fetch}")?;
207 };
208
209 Ok(())
210 }
211 DisplayFormatType::TreeRender => {
212 if let Some(fetch) = self.fetch {
213 writeln!(f, "limit={fetch}")?;
214 };
215
216 for (i, e) in self.expr().iter().enumerate() {
217 e.fmt_sql(f)?;
218 if i != self.expr().len() - 1 {
219 write!(f, ", ")?;
220 }
221 }
222
223 Ok(())
224 }
225 }
226 }
227}
228
229impl ExecutionPlan for SortPreservingMergeExec {
230 fn name(&self) -> &'static str {
231 "SortPreservingMergeExec"
232 }
233
234 fn properties(&self) -> &Arc<PlanProperties> {
236 &self.cache
237 }
238
239 fn fetch(&self) -> Option<usize> {
240 self.fetch
241 }
242
243 fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
245 Some(Arc::new(Self {
246 input: Arc::clone(&self.input),
247 expr: self.expr.clone(),
248 metrics: self.metrics.clone(),
249 fetch: limit,
250 cache: Arc::clone(&self.cache),
251 enable_round_robin_repartition: self.enable_round_robin_repartition,
252 }))
253 }
254
255 fn with_preserve_order(
256 &self,
257 preserve_order: bool,
258 ) -> Option<Arc<dyn ExecutionPlan>> {
259 self.input
260 .with_preserve_order(preserve_order)
261 .and_then(|new_input| {
262 Arc::new(self.clone())
263 .with_new_children(vec![new_input])
264 .ok()
265 })
266 }
267
268 fn required_input_distribution(&self) -> Vec<Distribution> {
269 vec![Distribution::UnspecifiedDistribution]
270 }
271
272 fn benefits_from_input_partitioning(&self) -> Vec<bool> {
273 vec![false]
274 }
275
276 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
277 vec![Some(OrderingRequirements::from(self.expr.clone()))]
278 }
279
280 fn maintains_input_order(&self) -> Vec<bool> {
281 vec![true]
282 }
283
284 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
285 vec![&self.input]
286 }
287
288 fn with_new_children(
289 self: Arc<Self>,
290 mut children: Vec<Arc<dyn ExecutionPlan>>,
291 ) -> Result<Arc<dyn ExecutionPlan>> {
292 check_if_same_properties!(self, children);
293 Ok(Arc::new(
294 SortPreservingMergeExec::new(self.expr.clone(), children.swap_remove(0))
295 .with_fetch(self.fetch),
296 ))
297 }
298
299 fn execute(
300 &self,
301 partition: usize,
302 context: Arc<TaskContext>,
303 ) -> Result<SendableRecordBatchStream> {
304 trace!("Start SortPreservingMergeExec::execute for partition: {partition}");
305 assert_eq_or_internal_err!(
306 partition,
307 0,
308 "SortPreservingMergeExec invalid partition {partition}"
309 );
310
311 let input_partitions = self.input.output_partitioning().partition_count();
312 trace!(
313 "Number of input partitions of SortPreservingMergeExec::execute: {input_partitions}"
314 );
315 let schema = self.schema();
316
317 let reservation =
318 MemoryConsumer::new(format!("SortPreservingMergeExec[{partition}]"))
319 .register(&context.runtime_env().memory_pool);
320
321 match input_partitions {
322 0 => internal_err!(
323 "SortPreservingMergeExec requires at least one input partition"
324 ),
325 1 => match self.fetch {
326 Some(fetch) => {
327 let stream = self.input.execute(0, context)?;
328 debug!(
329 "Done getting stream for SortPreservingMergeExec::execute with 1 input with {fetch}"
330 );
331 Ok(Box::pin(LimitStream::new(
332 stream,
333 0,
334 Some(fetch),
335 BaselineMetrics::new(&self.metrics, partition),
336 )))
337 }
338 None => {
339 let stream = self.input.execute(0, context);
340 debug!(
341 "Done getting stream for SortPreservingMergeExec::execute with 1 input without fetch"
342 );
343 stream
344 }
345 },
346 _ => {
347 let receivers = (0..input_partitions)
348 .map(|partition| {
349 let stream =
350 self.input.execute(partition, Arc::clone(&context))?;
351 Ok(spawn_buffered(stream, 1))
352 })
353 .collect::<Result<_>>()?;
354
355 debug!(
356 "Done setting up sender-receiver for SortPreservingMergeExec::execute"
357 );
358
359 let result = StreamingMergeBuilder::new()
360 .with_streams(receivers)
361 .with_schema(schema)
362 .with_expressions(&self.expr)
363 .with_metrics(BaselineMetrics::new(&self.metrics, partition))
364 .with_batch_size(context.session_config().batch_size())
365 .with_fetch(self.fetch)
366 .with_reservation(reservation)
367 .with_round_robin_tie_breaker(self.enable_round_robin_repartition)
368 .build()?;
369
370 debug!(
371 "Got stream result from SortPreservingMergeStream::new_from_receivers"
372 );
373
374 Ok(result)
375 }
376 }
377 }
378
379 fn metrics(&self) -> Option<MetricsSet> {
380 Some(self.metrics.clone_inner())
381 }
382
383 fn partition_statistics(&self, _partition: Option<usize>) -> Result<Arc<Statistics>> {
384 self.input.partition_statistics(None)
385 }
386
387 fn supports_limit_pushdown(&self) -> bool {
388 true
389 }
390
391 fn try_swapping_with_projection(
395 &self,
396 projection: &ProjectionExec,
397 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
398 if projection.expr().len() >= projection.input().schema().fields().len() {
400 return Ok(None);
401 }
402
403 let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())?
404 else {
405 return Ok(None);
406 };
407
408 Ok(Some(Arc::new(
409 SortPreservingMergeExec::new(
410 updated_exprs,
411 make_with_child(projection, self.input())?,
412 )
413 .with_fetch(self.fetch()),
414 )))
415 }
416}
417
418#[cfg(test)]
419mod tests {
420 use std::collections::HashSet;
421 use std::fmt::Formatter;
422 use std::pin::Pin;
423 use std::sync::Mutex;
424 use std::task::{Context, Poll, Waker, ready};
425 use std::time::Duration;
426
427 use super::*;
428 use crate::coalesce_partitions::CoalescePartitionsExec;
429 use crate::execution_plan::{Boundedness, EmissionType};
430 use crate::expressions::col;
431 use crate::metrics::{MetricValue, Timestamp};
432 use crate::repartition::RepartitionExec;
433 use crate::sorts::sort::SortExec;
434 use crate::stream::RecordBatchReceiverStream;
435 use crate::test::TestMemoryExec;
436 use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero};
437 use crate::test::{self, assert_is_pending, make_partition};
438 use crate::{collect, common};
439
440 use arrow::array::{
441 ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray,
442 TimestampNanosecondArray,
443 };
444 use arrow::compute::SortOptions;
445 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
446 use datafusion_common::test_util::batches_to_string;
447 use datafusion_common::{assert_batches_eq, exec_err};
448 use datafusion_common_runtime::SpawnedTask;
449 use datafusion_execution::RecordBatchStream;
450 use datafusion_execution::config::SessionConfig;
451 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
452 use datafusion_physical_expr::EquivalenceProperties;
453 use datafusion_physical_expr::expressions::Column;
454 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
455 use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
456
457 use futures::{FutureExt, Stream, StreamExt};
458 use insta::assert_snapshot;
459 use tokio::time::timeout;
460
461 fn generate_task_ctx_for_round_robin_tie_breaker(
464 target_batch_size: usize,
465 ) -> Result<Arc<TaskContext>> {
466 let runtime = RuntimeEnvBuilder::new()
467 .with_memory_limit(20_000_000, 1.0)
468 .build_arc()?;
469 let mut config = SessionConfig::new();
470 config.options_mut().execution.batch_size = target_batch_size;
471 let task_ctx = TaskContext::default()
472 .with_runtime(runtime)
473 .with_session_config(config);
474 Ok(Arc::new(task_ctx))
475 }
476
477 fn generate_spm_for_round_robin_tie_breaker(
480 enable_round_robin_repartition: bool,
481 ) -> Result<Arc<SortPreservingMergeExec>> {
482 let row_size = 12500;
483 let a: ArrayRef = Arc::new(Int32Array::from(vec![1; row_size]));
484 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"); row_size]));
485 let c: ArrayRef = Arc::new(Int64Array::from_iter(vec![0; row_size]));
486 let rb = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)])?;
487 let schema = rb.schema();
488
489 let rbs = std::iter::repeat_n(rb, 1024).collect::<Vec<_>>();
490 let sort = [
491 PhysicalSortExpr {
492 expr: col("b", &schema)?,
493 options: Default::default(),
494 },
495 PhysicalSortExpr {
496 expr: col("c", &schema)?,
497 options: Default::default(),
498 },
499 ]
500 .into();
501
502 let repartition_exec = RepartitionExec::try_new(
503 TestMemoryExec::try_new_exec(&[rbs], schema, None)?,
504 Partitioning::RoundRobinBatch(2),
505 )?;
506 let spm = SortPreservingMergeExec::new(sort, Arc::new(repartition_exec))
507 .with_round_robin_repartition(enable_round_robin_repartition);
508 Ok(Arc::new(spm))
509 }
510
511 #[tokio::test(flavor = "multi_thread")]
517 async fn test_round_robin_tie_breaker_success() -> Result<()> {
518 let target_batch_size = 12500;
519 let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(target_batch_size)?;
520 let spm = generate_spm_for_round_robin_tie_breaker(true)?;
521 let _collected = collect(spm, task_ctx).await?;
522 Ok(())
523 }
524
525 #[tokio::test(flavor = "multi_thread")]
531 async fn test_round_robin_tie_breaker_fail() -> Result<()> {
532 let task_ctx = generate_task_ctx_for_round_robin_tie_breaker(8192)?;
533 let spm = generate_spm_for_round_robin_tie_breaker(false)?;
534 let _err = collect(spm, task_ctx).await.unwrap_err();
535 Ok(())
536 }
537
538 #[tokio::test]
539 async fn test_merge_interleave() {
540 let task_ctx = Arc::new(TaskContext::default());
541 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
542 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
543 Some("a"),
544 Some("c"),
545 Some("e"),
546 Some("g"),
547 Some("j"),
548 ]));
549 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
550 let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
551
552 let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
553 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
554 Some("b"),
555 Some("d"),
556 Some("f"),
557 Some("h"),
558 Some("j"),
559 ]));
560 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
561 let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
562
563 _test_merge(
564 &[vec![b1], vec![b2]],
565 &[
566 "+----+---+-------------------------------+",
567 "| a | b | c |",
568 "+----+---+-------------------------------+",
569 "| 1 | a | 1970-01-01T00:00:00.000000008 |",
570 "| 10 | b | 1970-01-01T00:00:00.000000004 |",
571 "| 2 | c | 1970-01-01T00:00:00.000000007 |",
572 "| 20 | d | 1970-01-01T00:00:00.000000006 |",
573 "| 7 | e | 1970-01-01T00:00:00.000000006 |",
574 "| 70 | f | 1970-01-01T00:00:00.000000002 |",
575 "| 9 | g | 1970-01-01T00:00:00.000000005 |",
576 "| 90 | h | 1970-01-01T00:00:00.000000002 |",
577 "| 30 | j | 1970-01-01T00:00:00.000000006 |", "| 3 | j | 1970-01-01T00:00:00.000000008 |",
579 "+----+---+-------------------------------+",
580 ],
581 task_ctx,
582 )
583 .await;
584 }
585
586 #[tokio::test]
587 async fn test_merge_some_overlap() {
588 let task_ctx = Arc::new(TaskContext::default());
589 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
590 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
591 Some("a"),
592 Some("b"),
593 Some("c"),
594 Some("d"),
595 Some("e"),
596 ]));
597 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
598 let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
599
600 let a: ArrayRef = Arc::new(Int32Array::from(vec![70, 90, 30, 100, 110]));
601 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
602 Some("c"),
603 Some("d"),
604 Some("e"),
605 Some("f"),
606 Some("g"),
607 ]));
608 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
609 let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
610
611 _test_merge(
612 &[vec![b1], vec![b2]],
613 &[
614 "+-----+---+-------------------------------+",
615 "| a | b | c |",
616 "+-----+---+-------------------------------+",
617 "| 1 | a | 1970-01-01T00:00:00.000000008 |",
618 "| 2 | b | 1970-01-01T00:00:00.000000007 |",
619 "| 70 | c | 1970-01-01T00:00:00.000000004 |",
620 "| 7 | c | 1970-01-01T00:00:00.000000006 |",
621 "| 9 | d | 1970-01-01T00:00:00.000000005 |",
622 "| 90 | d | 1970-01-01T00:00:00.000000006 |",
623 "| 30 | e | 1970-01-01T00:00:00.000000002 |",
624 "| 3 | e | 1970-01-01T00:00:00.000000008 |",
625 "| 100 | f | 1970-01-01T00:00:00.000000002 |",
626 "| 110 | g | 1970-01-01T00:00:00.000000006 |",
627 "+-----+---+-------------------------------+",
628 ],
629 task_ctx,
630 )
631 .await;
632 }
633
634 #[tokio::test]
635 async fn test_merge_no_overlap() {
636 let task_ctx = Arc::new(TaskContext::default());
637 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
638 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
639 Some("a"),
640 Some("b"),
641 Some("c"),
642 Some("d"),
643 Some("e"),
644 ]));
645 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
646 let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
647
648 let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
649 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
650 Some("f"),
651 Some("g"),
652 Some("h"),
653 Some("i"),
654 Some("j"),
655 ]));
656 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
657 let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
658
659 _test_merge(
660 &[vec![b1], vec![b2]],
661 &[
662 "+----+---+-------------------------------+",
663 "| a | b | c |",
664 "+----+---+-------------------------------+",
665 "| 1 | a | 1970-01-01T00:00:00.000000008 |",
666 "| 2 | b | 1970-01-01T00:00:00.000000007 |",
667 "| 7 | c | 1970-01-01T00:00:00.000000006 |",
668 "| 9 | d | 1970-01-01T00:00:00.000000005 |",
669 "| 3 | e | 1970-01-01T00:00:00.000000008 |",
670 "| 10 | f | 1970-01-01T00:00:00.000000004 |",
671 "| 20 | g | 1970-01-01T00:00:00.000000006 |",
672 "| 70 | h | 1970-01-01T00:00:00.000000002 |",
673 "| 90 | i | 1970-01-01T00:00:00.000000002 |",
674 "| 30 | j | 1970-01-01T00:00:00.000000006 |",
675 "+----+---+-------------------------------+",
676 ],
677 task_ctx,
678 )
679 .await;
680 }
681
682 #[tokio::test]
683 async fn test_merge_three_partitions() {
684 let task_ctx = Arc::new(TaskContext::default());
685 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
686 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
687 Some("a"),
688 Some("b"),
689 Some("c"),
690 Some("d"),
691 Some("f"),
692 ]));
693 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![8, 7, 6, 5, 8]));
694 let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
695
696 let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 70, 90, 30]));
697 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
698 Some("e"),
699 Some("g"),
700 Some("h"),
701 Some("i"),
702 Some("j"),
703 ]));
704 let c: ArrayRef =
705 Arc::new(TimestampNanosecondArray::from(vec![40, 60, 20, 20, 60]));
706 let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
707
708 let a: ArrayRef = Arc::new(Int32Array::from(vec![100, 200, 700, 900, 300]));
709 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
710 Some("f"),
711 Some("g"),
712 Some("h"),
713 Some("i"),
714 Some("j"),
715 ]));
716 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![4, 6, 2, 2, 6]));
717 let b3 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
718
719 _test_merge(
720 &[vec![b1], vec![b2], vec![b3]],
721 &[
722 "+-----+---+-------------------------------+",
723 "| a | b | c |",
724 "+-----+---+-------------------------------+",
725 "| 1 | a | 1970-01-01T00:00:00.000000008 |",
726 "| 2 | b | 1970-01-01T00:00:00.000000007 |",
727 "| 7 | c | 1970-01-01T00:00:00.000000006 |",
728 "| 9 | d | 1970-01-01T00:00:00.000000005 |",
729 "| 10 | e | 1970-01-01T00:00:00.000000040 |",
730 "| 100 | f | 1970-01-01T00:00:00.000000004 |",
731 "| 3 | f | 1970-01-01T00:00:00.000000008 |",
732 "| 200 | g | 1970-01-01T00:00:00.000000006 |",
733 "| 20 | g | 1970-01-01T00:00:00.000000060 |",
734 "| 700 | h | 1970-01-01T00:00:00.000000002 |",
735 "| 70 | h | 1970-01-01T00:00:00.000000020 |",
736 "| 900 | i | 1970-01-01T00:00:00.000000002 |",
737 "| 90 | i | 1970-01-01T00:00:00.000000020 |",
738 "| 300 | j | 1970-01-01T00:00:00.000000006 |",
739 "| 30 | j | 1970-01-01T00:00:00.000000060 |",
740 "+-----+---+-------------------------------+",
741 ],
742 task_ctx,
743 )
744 .await;
745 }
746
747 async fn _test_merge(
748 partitions: &[Vec<RecordBatch>],
749 exp: &[&str],
750 context: Arc<TaskContext>,
751 ) {
752 let schema = partitions[0][0].schema();
753 let sort = [
754 PhysicalSortExpr {
755 expr: col("b", &schema).unwrap(),
756 options: Default::default(),
757 },
758 PhysicalSortExpr {
759 expr: col("c", &schema).unwrap(),
760 options: Default::default(),
761 },
762 ]
763 .into();
764 let exec = TestMemoryExec::try_new_exec(partitions, schema, None).unwrap();
765 let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
766
767 let collected = collect(merge, context).await.unwrap();
768 assert_batches_eq!(exp, collected.as_slice());
769 }
770
771 async fn sorted_merge(
772 input: Arc<dyn ExecutionPlan>,
773 sort: LexOrdering,
774 context: Arc<TaskContext>,
775 ) -> RecordBatch {
776 let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
777 let mut result = collect(merge, context).await.unwrap();
778 assert_eq!(result.len(), 1);
779 result.remove(0)
780 }
781
782 async fn partition_sort(
783 input: Arc<dyn ExecutionPlan>,
784 sort: LexOrdering,
785 context: Arc<TaskContext>,
786 ) -> RecordBatch {
787 let sort_exec =
788 Arc::new(SortExec::new(sort.clone(), input).with_preserve_partitioning(true));
789 sorted_merge(sort_exec, sort, context).await
790 }
791
792 async fn basic_sort(
793 src: Arc<dyn ExecutionPlan>,
794 sort: LexOrdering,
795 context: Arc<TaskContext>,
796 ) -> RecordBatch {
797 let merge = Arc::new(CoalescePartitionsExec::new(src));
798 let sort_exec = Arc::new(SortExec::new(sort, merge));
799 let mut result = collect(sort_exec, context).await.unwrap();
800 assert_eq!(result.len(), 1);
801 result.remove(0)
802 }
803
804 #[tokio::test]
805 async fn test_partition_sort() -> Result<()> {
806 let task_ctx = Arc::new(TaskContext::default());
807 let partitions = 4;
808 let csv = test::scan_partitioned(partitions);
809 let schema = csv.schema();
810
811 let sort: LexOrdering = [PhysicalSortExpr {
812 expr: col("i", &schema)?,
813 options: SortOptions {
814 descending: true,
815 nulls_first: true,
816 },
817 }]
818 .into();
819
820 let basic =
821 basic_sort(Arc::clone(&csv), sort.clone(), Arc::clone(&task_ctx)).await;
822 let partition = partition_sort(csv, sort, Arc::clone(&task_ctx)).await;
823
824 let basic = arrow::util::pretty::pretty_format_batches(&[basic])
825 .unwrap()
826 .to_string();
827 let partition = arrow::util::pretty::pretty_format_batches(&[partition])
828 .unwrap()
829 .to_string();
830
831 assert_eq!(
832 basic, partition,
833 "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
834 );
835
836 Ok(())
837 }
838
839 fn split_batch(sorted: &RecordBatch, batch_size: usize) -> Vec<RecordBatch> {
841 let batches = sorted.num_rows().div_ceil(batch_size);
842
843 (0..batches)
845 .map(|batch_idx| {
846 let columns = (0..sorted.num_columns())
847 .map(|column_idx| {
848 let length =
849 batch_size.min(sorted.num_rows() - batch_idx * batch_size);
850
851 sorted
852 .column(column_idx)
853 .slice(batch_idx * batch_size, length)
854 })
855 .collect();
856
857 RecordBatch::try_new(sorted.schema(), columns).unwrap()
858 })
859 .collect()
860 }
861
862 async fn sorted_partitioned_input(
863 sort: LexOrdering,
864 sizes: &[usize],
865 context: Arc<TaskContext>,
866 ) -> Result<Arc<dyn ExecutionPlan>> {
867 let partitions = 4;
868 let csv = test::scan_partitioned(partitions);
869
870 let sorted = basic_sort(csv, sort, context).await;
871 let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect();
872
873 TestMemoryExec::try_new_exec(&split, sorted.schema(), None).map(|e| e as _)
874 }
875
876 #[tokio::test]
877 async fn test_partition_sort_streaming_input() -> Result<()> {
878 let task_ctx = Arc::new(TaskContext::default());
879 let schema = make_partition(11).schema();
880 let sort: LexOrdering = [PhysicalSortExpr {
881 expr: col("i", &schema)?,
882 options: Default::default(),
883 }]
884 .into();
885
886 let input =
887 sorted_partitioned_input(sort.clone(), &[10, 3, 11], Arc::clone(&task_ctx))
888 .await?;
889 let basic =
890 basic_sort(Arc::clone(&input), sort.clone(), Arc::clone(&task_ctx)).await;
891 let partition = sorted_merge(input, sort, Arc::clone(&task_ctx)).await;
892
893 assert_eq!(basic.num_rows(), 1200);
894 assert_eq!(partition.num_rows(), 1200);
895
896 let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
897 let partition =
898 arrow::util::pretty::pretty_format_batches(&[partition])?.to_string();
899
900 assert_eq!(basic, partition);
901
902 Ok(())
903 }
904
905 #[tokio::test]
906 async fn test_partition_sort_streaming_input_output() -> Result<()> {
907 let schema = make_partition(11).schema();
908 let sort: LexOrdering = [PhysicalSortExpr {
909 expr: col("i", &schema)?,
910 options: Default::default(),
911 }]
912 .into();
913
914 let task_ctx = Arc::new(TaskContext::default());
916 let input =
917 sorted_partitioned_input(sort.clone(), &[10, 5, 13], Arc::clone(&task_ctx))
918 .await?;
919 let basic = basic_sort(Arc::clone(&input), sort.clone(), task_ctx).await;
920
921 let task_ctx = TaskContext::default()
923 .with_session_config(SessionConfig::new().with_batch_size(23));
924 let task_ctx = Arc::new(task_ctx);
925
926 let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
927 let merged = collect(merge, task_ctx).await?;
928
929 assert_eq!(merged.len(), 53);
930 assert_eq!(basic.num_rows(), 1200);
931 assert_eq!(merged.iter().map(|x| x.num_rows()).sum::<usize>(), 1200);
932
933 let basic = arrow::util::pretty::pretty_format_batches(&[basic])?.to_string();
934 let partition = arrow::util::pretty::pretty_format_batches(&merged)?.to_string();
935
936 assert_eq!(basic, partition);
937
938 Ok(())
939 }
940
941 #[tokio::test]
942 async fn test_nulls() {
943 let task_ctx = Arc::new(TaskContext::default());
944 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
945 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
946 None,
947 Some("a"),
948 Some("b"),
949 Some("d"),
950 Some("e"),
951 ]));
952 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
953 Some(8),
954 None,
955 Some(6),
956 None,
957 Some(4),
958 ]));
959 let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
960
961 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
962 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![
963 None,
964 Some("b"),
965 Some("g"),
966 Some("h"),
967 Some("i"),
968 ]));
969 let c: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![
970 Some(8),
971 None,
972 Some(5),
973 None,
974 Some(4),
975 ]));
976 let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
977 let schema = b1.schema();
978
979 let sort = [
980 PhysicalSortExpr {
981 expr: col("b", &schema).unwrap(),
982 options: SortOptions {
983 descending: false,
984 nulls_first: true,
985 },
986 },
987 PhysicalSortExpr {
988 expr: col("c", &schema).unwrap(),
989 options: SortOptions {
990 descending: false,
991 nulls_first: false,
992 },
993 },
994 ]
995 .into();
996 let exec =
997 TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
998 let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
999
1000 let collected = collect(merge, task_ctx).await.unwrap();
1001 assert_eq!(collected.len(), 1);
1002
1003 assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1004 +---+---+-------------------------------+
1005 | a | b | c |
1006 +---+---+-------------------------------+
1007 | 1 | | 1970-01-01T00:00:00.000000008 |
1008 | 1 | | 1970-01-01T00:00:00.000000008 |
1009 | 2 | a | |
1010 | 7 | b | 1970-01-01T00:00:00.000000006 |
1011 | 2 | b | |
1012 | 9 | d | |
1013 | 3 | e | 1970-01-01T00:00:00.000000004 |
1014 | 3 | g | 1970-01-01T00:00:00.000000005 |
1015 | 4 | h | |
1016 | 5 | i | 1970-01-01T00:00:00.000000004 |
1017 +---+---+-------------------------------+
1018 ");
1019 }
1020
1021 #[tokio::test]
1022 async fn test_sort_merge_single_partition_with_fetch() {
1023 let task_ctx = Arc::new(TaskContext::default());
1024 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
1025 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1026 let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1027 let schema = batch.schema();
1028
1029 let sort = [PhysicalSortExpr {
1030 expr: col("b", &schema).unwrap(),
1031 options: SortOptions {
1032 descending: false,
1033 nulls_first: true,
1034 },
1035 }]
1036 .into();
1037 let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
1038 let merge =
1039 Arc::new(SortPreservingMergeExec::new(sort, exec).with_fetch(Some(2)));
1040
1041 let collected = collect(merge, task_ctx).await.unwrap();
1042 assert_eq!(collected.len(), 1);
1043
1044 assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1045 +---+---+
1046 | a | b |
1047 +---+---+
1048 | 1 | a |
1049 | 2 | b |
1050 +---+---+
1051 ");
1052 }
1053
1054 #[tokio::test]
1055 async fn test_sort_merge_single_partition_without_fetch() {
1056 let task_ctx = Arc::new(TaskContext::default());
1057 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 7, 9, 3]));
1058 let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"]));
1059 let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1060 let schema = batch.schema();
1061
1062 let sort = [PhysicalSortExpr {
1063 expr: col("b", &schema).unwrap(),
1064 options: SortOptions {
1065 descending: false,
1066 nulls_first: true,
1067 },
1068 }]
1069 .into();
1070 let exec = TestMemoryExec::try_new_exec(&[vec![batch]], schema, None).unwrap();
1071 let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1072
1073 let collected = collect(merge, task_ctx).await.unwrap();
1074 assert_eq!(collected.len(), 1);
1075
1076 assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1077 +---+---+
1078 | a | b |
1079 +---+---+
1080 | 1 | a |
1081 | 2 | b |
1082 | 7 | c |
1083 | 9 | d |
1084 | 3 | e |
1085 +---+---+
1086 ");
1087 }
1088
1089 #[tokio::test]
1090 async fn test_async() -> Result<()> {
1091 let task_ctx = Arc::new(TaskContext::default());
1092 let schema = make_partition(11).schema();
1093 let sort: LexOrdering = [PhysicalSortExpr {
1094 expr: col("i", &schema).unwrap(),
1095 options: SortOptions::default(),
1096 }]
1097 .into();
1098
1099 let batches =
1100 sorted_partitioned_input(sort.clone(), &[5, 7, 3], Arc::clone(&task_ctx))
1101 .await?;
1102
1103 let partition_count = batches.output_partitioning().partition_count();
1104 let mut streams = Vec::with_capacity(partition_count);
1105
1106 for partition in 0..partition_count {
1107 let mut builder = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);
1108
1109 let sender = builder.tx();
1110
1111 let mut stream = batches.execute(partition, Arc::clone(&task_ctx)).unwrap();
1112 builder.spawn(async move {
1113 while let Some(batch) = stream.next().await {
1114 sender.send(batch).await.unwrap();
1115 tokio::time::sleep(Duration::from_millis(10)).await;
1117 }
1118
1119 Ok(())
1120 });
1121
1122 streams.push(builder.build());
1123 }
1124
1125 let metrics = ExecutionPlanMetricsSet::new();
1126 let reservation =
1127 MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);
1128
1129 let fetch = None;
1130 let merge_stream = StreamingMergeBuilder::new()
1131 .with_streams(streams)
1132 .with_schema(batches.schema())
1133 .with_expressions(&sort)
1134 .with_metrics(BaselineMetrics::new(&metrics, 0))
1135 .with_batch_size(task_ctx.session_config().batch_size())
1136 .with_fetch(fetch)
1137 .with_reservation(reservation)
1138 .build()?;
1139
1140 let mut merged = common::collect(merge_stream).await.unwrap();
1141
1142 assert_eq!(merged.len(), 1);
1143 let merged = merged.remove(0);
1144 let basic = basic_sort(batches, sort.clone(), Arc::clone(&task_ctx)).await;
1145
1146 let basic = arrow::util::pretty::pretty_format_batches(&[basic])
1147 .unwrap()
1148 .to_string();
1149 let partition = arrow::util::pretty::pretty_format_batches(&[merged])
1150 .unwrap()
1151 .to_string();
1152
1153 assert_eq!(
1154 basic, partition,
1155 "basic:\n\n{basic}\n\npartition:\n\n{partition}\n\n"
1156 );
1157
1158 Ok(())
1159 }
1160
1161 #[tokio::test]
1162 async fn test_merge_metrics() {
1163 let task_ctx = Arc::new(TaskContext::default());
1164 let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
1165 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
1166 let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1167
1168 let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
1169 let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
1170 let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();
1171
1172 let schema = b1.schema();
1173 let sort = [PhysicalSortExpr {
1174 expr: col("b", &schema).unwrap(),
1175 options: Default::default(),
1176 }]
1177 .into();
1178 let exec =
1179 TestMemoryExec::try_new_exec(&[vec![b1], vec![b2]], schema, None).unwrap();
1180 let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1181
1182 let collected = collect(Arc::clone(&merge) as Arc<dyn ExecutionPlan>, task_ctx)
1183 .await
1184 .unwrap();
1185 assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1186 +----+---+
1187 | a | b |
1188 +----+---+
1189 | 1 | a |
1190 | 10 | b |
1191 | 2 | c |
1192 | 20 | d |
1193 +----+---+
1194 ");
1195
1196 let metrics = merge.metrics().unwrap();
1198
1199 assert_eq!(metrics.output_rows().unwrap(), 4);
1200 assert!(metrics.elapsed_compute().unwrap() > 0);
1201
1202 let mut saw_start = false;
1203 let mut saw_end = false;
1204 metrics.iter().for_each(|m| match m.value() {
1205 MetricValue::StartTimestamp(ts) => {
1206 saw_start = true;
1207 assert!(nanos_from_timestamp(ts) > 0);
1208 }
1209 MetricValue::EndTimestamp(ts) => {
1210 saw_end = true;
1211 assert!(nanos_from_timestamp(ts) > 0);
1212 }
1213 _ => {}
1214 });
1215
1216 assert!(saw_start);
1217 assert!(saw_end);
1218 }
1219
1220 fn nanos_from_timestamp(ts: &Timestamp) -> i64 {
1221 ts.value().unwrap().timestamp_nanos_opt().unwrap()
1222 }
1223
1224 #[tokio::test]
1225 async fn test_drop_cancel() -> Result<()> {
1226 let task_ctx = Arc::new(TaskContext::default());
1227 let schema =
1228 Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));
1229
1230 let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
1231 let refs = blocking_exec.refs();
1232 let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
1233 [PhysicalSortExpr {
1234 expr: col("a", &schema)?,
1235 options: SortOptions::default(),
1236 }]
1237 .into(),
1238 blocking_exec,
1239 ));
1240
1241 let fut = collect(sort_preserving_merge_exec, task_ctx);
1242 let mut fut = fut.boxed();
1243
1244 assert_is_pending(&mut fut);
1245 drop(fut);
1246 assert_strong_count_converges_to_zero(refs).await;
1247
1248 Ok(())
1249 }
1250
1251 #[tokio::test]
1252 async fn test_stable_sort() {
1253 let task_ctx = Arc::new(TaskContext::default());
1254
1255 let partitions: Vec<Vec<RecordBatch>> = (0..10)
1263 .map(|batch_number| {
1264 let batch_number: Int32Array =
1265 vec![Some(batch_number), Some(batch_number)]
1266 .into_iter()
1267 .collect();
1268 let value: StringArray = vec![Some("A"), Some("B")].into_iter().collect();
1269
1270 let batch = RecordBatch::try_from_iter(vec![
1271 ("batch_number", Arc::new(batch_number) as ArrayRef),
1272 ("value", Arc::new(value) as ArrayRef),
1273 ])
1274 .unwrap();
1275
1276 vec![batch]
1277 })
1278 .collect();
1279
1280 let schema = partitions[0][0].schema();
1281
1282 let sort = [PhysicalSortExpr {
1283 expr: col("value", &schema).unwrap(),
1284 options: SortOptions {
1285 descending: false,
1286 nulls_first: true,
1287 },
1288 }]
1289 .into();
1290
1291 let exec = TestMemoryExec::try_new_exec(&partitions, schema, None).unwrap();
1292 let merge = Arc::new(SortPreservingMergeExec::new(sort, exec));
1293
1294 let collected = collect(merge, task_ctx).await.unwrap();
1295 assert_eq!(collected.len(), 1);
1296
1297 assert_snapshot!(batches_to_string(collected.as_slice()), @r"
1301 +--------------+-------+
1302 | batch_number | value |
1303 +--------------+-------+
1304 | 0 | A |
1305 | 1 | A |
1306 | 2 | A |
1307 | 3 | A |
1308 | 4 | A |
1309 | 5 | A |
1310 | 6 | A |
1311 | 7 | A |
1312 | 8 | A |
1313 | 9 | A |
1314 | 0 | B |
1315 | 1 | B |
1316 | 2 | B |
1317 | 3 | B |
1318 | 4 | B |
1319 | 5 | B |
1320 | 6 | B |
1321 | 7 | B |
1322 | 8 | B |
1323 | 9 | B |
1324 +--------------+-------+
1325 ");
1326 }
1327
/// Mutable bookkeeping shared by every partition stream of one test source.
#[derive(Debug)]
struct CongestionState {
    /// Wakers of tasks that returned `Poll::Pending` and are waiting to be resumed.
    wakers: Vec<Waker>,
    /// Partitions that have not called `check_congested` yet.
    unpolled_partitions: HashSet<usize>,
}

/// Gate that stays closed until every partition has been polled at least once.
#[derive(Debug)]
struct Congestion {
    congestion_state: Mutex<CongestionState>,
}

impl Congestion {
    /// Creates a gate that waits for partitions `0..partition_count`.
    fn new(partition_count: usize) -> Self {
        Congestion {
            congestion_state: Mutex::new(CongestionState {
                wakers: vec![],
                unpolled_partitions: (0usize..partition_count).collect(),
            }),
        }
    }

    /// Records that `partition` was polled and reports whether the gate is open.
    ///
    /// Returns `Poll::Pending` (after storing the caller's waker) while some
    /// partition is still unpolled. Once the last one checks in, all stored
    /// wakers are woken and `Poll::Ready(())` is returned.
    ///
    /// Panics if the internal mutex is poisoned.
    fn check_congested(&self, partition: usize, cx: &mut Context<'_>) -> Poll<()> {
        let parked = {
            let mut state = self.congestion_state.lock().unwrap();
            state.unpolled_partitions.remove(&partition);

            if !state.unpolled_partitions.is_empty() {
                state.wakers.push(cx.waker().clone());
                return Poll::Pending;
            }

            std::mem::take(&mut state.wakers)
        };

        // Wake only after the guard is dropped: a waker may run arbitrary code,
        // including code that locks this state again.
        parked.into_iter().for_each(Waker::wake);
        Poll::Ready(())
    }
}
1364
1365 #[derive(Debug, Clone)]
1368 struct CongestedExec {
1369 schema: Schema,
1370 cache: Arc<PlanProperties>,
1371 congestion: Arc<Congestion>,
1372 }
1373
1374 impl CongestedExec {
1375 fn compute_properties(schema: SchemaRef) -> PlanProperties {
1376 let columns = schema
1377 .fields
1378 .iter()
1379 .enumerate()
1380 .map(|(i, f)| Arc::new(Column::new(f.name(), i)) as Arc<dyn PhysicalExpr>)
1381 .collect::<Vec<_>>();
1382 let mut eq_properties = EquivalenceProperties::new(schema);
1383 eq_properties.add_ordering(
1384 columns
1385 .iter()
1386 .map(|expr| PhysicalSortExpr::new_default(Arc::clone(expr))),
1387 );
1388 PlanProperties::new(
1389 eq_properties,
1390 Partitioning::Hash(columns, 3),
1391 EmissionType::Incremental,
1392 Boundedness::Unbounded {
1393 requires_infinite_memory: false,
1394 },
1395 )
1396 }
1397 }
1398
1399 impl ExecutionPlan for CongestedExec {
1400 fn name(&self) -> &'static str {
1401 Self::static_name()
1402 }
1403 fn properties(&self) -> &Arc<PlanProperties> {
1404 &self.cache
1405 }
1406 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1407 vec![]
1408 }
1409 fn with_new_children(
1410 self: Arc<Self>,
1411 _: Vec<Arc<dyn ExecutionPlan>>,
1412 ) -> Result<Arc<dyn ExecutionPlan>> {
1413 Ok(self)
1414 }
1415 fn execute(
1416 &self,
1417 partition: usize,
1418 _context: Arc<TaskContext>,
1419 ) -> Result<SendableRecordBatchStream> {
1420 Ok(Box::pin(CongestedStream {
1421 schema: Arc::new(self.schema.clone()),
1422 none_polled_once: false,
1423 congestion: Arc::clone(&self.congestion),
1424 partition,
1425 }))
1426 }
1427 }
1428
1429 impl DisplayAs for CongestedExec {
1430 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
1431 match t {
1432 DisplayFormatType::Default | DisplayFormatType::Verbose => {
1433 write!(f, "CongestedExec",).unwrap()
1434 }
1435 DisplayFormatType::TreeRender => {
1436 write!(f, "").unwrap()
1438 }
1439 }
1440 Ok(())
1441 }
1442 }
1443
1444 #[derive(Debug)]
1447 pub struct CongestedStream {
1448 schema: SchemaRef,
1449 none_polled_once: bool,
1450 congestion: Arc<Congestion>,
1451 partition: usize,
1452 }
1453
1454 impl Stream for CongestedStream {
1455 type Item = Result<RecordBatch>;
1456 fn poll_next(
1457 mut self: Pin<&mut Self>,
1458 cx: &mut Context<'_>,
1459 ) -> Poll<Option<Self::Item>> {
1460 match self.partition {
1461 0 => {
1462 let _ = self.congestion.check_congested(self.partition, cx);
1463 if self.none_polled_once {
1464 panic!("Exhausted stream is polled more than once")
1465 } else {
1466 self.none_polled_once = true;
1467 Poll::Ready(None)
1468 }
1469 }
1470 _ => {
1471 ready!(self.congestion.check_congested(self.partition, cx));
1472 Poll::Ready(None)
1473 }
1474 }
1475 }
1476 }
1477
1478 impl RecordBatchStream for CongestedStream {
1479 fn schema(&self) -> SchemaRef {
1480 Arc::clone(&self.schema)
1481 }
1482 }
1483
1484 #[tokio::test]
1485 async fn test_spm_congestion() -> Result<()> {
1486 let task_ctx = Arc::new(TaskContext::default());
1487 let schema = Schema::new(vec![Field::new("c1", DataType::UInt64, false)]);
1488 let properties = CongestedExec::compute_properties(Arc::new(schema.clone()));
1489 let &partition_count = match properties.output_partitioning() {
1490 Partitioning::RoundRobinBatch(partitions) => partitions,
1491 Partitioning::Hash(_, partitions) => partitions,
1492 Partitioning::UnknownPartitioning(partitions) => partitions,
1493 };
1494 let source = CongestedExec {
1495 schema: schema.clone(),
1496 cache: Arc::new(properties),
1497 congestion: Arc::new(Congestion::new(partition_count)),
1498 };
1499 let spm = SortPreservingMergeExec::new(
1500 [PhysicalSortExpr::new_default(Arc::new(Column::new(
1501 "c1", 0,
1502 )))]
1503 .into(),
1504 Arc::new(source),
1505 );
1506 let spm_task = SpawnedTask::spawn(collect(Arc::new(spm), task_ctx));
1507
1508 let result = timeout(Duration::from_secs(3), spm_task.join()).await;
1509 match result {
1510 Ok(Ok(Ok(_batches))) => Ok(()),
1511 Ok(Ok(Err(e))) => Err(e),
1512 Ok(Err(_)) => exec_err!("SortPreservingMerge task panicked or was cancelled"),
1513 Err(_) => exec_err!("SortPreservingMerge caused a deadlock"),
1514 }
1515 }
1516
1517 #[tokio::test]
1518 async fn test_sort_merge_stops_after_error_with_buffered_rows() -> Result<()> {
1519 let task_ctx = Arc::new(TaskContext::default());
1520 let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, false)]));
1521 let sort: LexOrdering = [PhysicalSortExpr::new_default(Arc::new(Column::new(
1522 "i", 0,
1523 ))
1524 as Arc<dyn PhysicalExpr>)]
1525 .into();
1526
1527 let mut stream0 = RecordBatchReceiverStream::builder(Arc::clone(&schema), 2);
1528 let tx0 = stream0.tx();
1529 let schema0 = Arc::clone(&schema);
1530 stream0.spawn(async move {
1531 let batch =
1532 RecordBatch::try_new(schema0, vec![Arc::new(Int32Array::from(vec![1]))])?;
1533 tx0.send(Ok(batch)).await.unwrap();
1534 tx0.send(exec_err!("stream failure")).await.unwrap();
1535 Ok(())
1536 });
1537
1538 let mut stream1 = RecordBatchReceiverStream::builder(Arc::clone(&schema), 1);
1539 let tx1 = stream1.tx();
1540 let schema1 = Arc::clone(&schema);
1541 stream1.spawn(async move {
1542 let batch =
1543 RecordBatch::try_new(schema1, vec![Arc::new(Int32Array::from(vec![2]))])?;
1544 tx1.send(Ok(batch)).await.unwrap();
1545 Ok(())
1546 });
1547
1548 let metrics = ExecutionPlanMetricsSet::new();
1549 let reservation =
1550 MemoryConsumer::new("test").register(&task_ctx.runtime_env().memory_pool);
1551
1552 let mut merge_stream = StreamingMergeBuilder::new()
1553 .with_streams(vec![stream0.build(), stream1.build()])
1554 .with_schema(Arc::clone(&schema))
1555 .with_expressions(&sort)
1556 .with_metrics(BaselineMetrics::new(&metrics, 0))
1557 .with_batch_size(task_ctx.session_config().batch_size())
1558 .with_fetch(None)
1559 .with_reservation(reservation)
1560 .build()?;
1561
1562 let first = merge_stream.next().await.unwrap();
1563 assert!(first.is_err(), "expected merge stream to surface the error");
1564 assert!(
1565 merge_stream.next().await.is_none(),
1566 "merge stream yielded data after returning an error"
1567 );
1568
1569 Ok(())
1570 }
1571}