// datafusion_physical_plan/coalesce_partitions.rs
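
//! [`CoalescePartitionsExec`]: an [`ExecutionPlan`] that merges all of its
//! input partitions into a single output partition.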
use std::any::Any;
use std::sync::Arc;

use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::stream::{ObservedStream, RecordBatchReceiverStream};
use super::{
    DisplayAs, ExecutionPlanProperties, PlanProperties, SendableRecordBatchStream,
    Statistics,
};
use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType};
use crate::filter_pushdown::{FilterDescription, FilterPushdownPhase};
use crate::projection::{make_with_child, ProjectionExec};
use crate::{DisplayFormatType, ExecutionPlan, Partitioning};

use datafusion_common::config::ConfigOptions;
use datafusion_common::{internal_err, Result};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::PhysicalExpr;
#[derive(Debug, Clone)]
pub struct CoalescePartitionsExec {
    /// Input execution plan
    input: Arc<dyn ExecutionPlan>,
    /// Execution metrics
    metrics: ExecutionPlanMetricsSet,
    /// Cached plan properties (equivalences, output partitioning, etc.)
    cache: PlanProperties,
    /// Optional number of rows to fetch; once reached, no more rows are produced
    pub(crate) fetch: Option<usize>,
}

impl CoalescePartitionsExec {
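    /// Create a new CoalescePartitionsExec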
    pub fn new(input: Arc<dyn ExecutionPlan>) -> Self {
        let cache = Self::compute_properties(&input);
        CoalescePartitionsExec {
            input,
            metrics: ExecutionPlanMetricsSet::new(),
            cache,
            fetch: None,
        }
    }

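    /// Set the fetch limit: the operator stops producing rows after emitting
    /// `fetch` rows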
    pub fn with_fetch(mut self, fetch: Option<usize>) -> Self {
        self.fetch = fetch;
        self
    }

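    /// Return a reference to the input execution plan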
    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
        &self.input
    }

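    /// Compute the cached [`PlanProperties`] for this node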
    fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties {
        let input_partitions = input.output_partitioning().partition_count();
        // With more than one input partition, this operator drives its inputs
        // eagerly and schedules cooperatively; otherwise it simply inherits
        // the input's behavior.
        let (drive, scheduling) = if input_partitions > 1 {
            (EvaluationType::Eager, SchedulingType::Cooperative)
        } else {
            (
                input.properties().evaluation_type,
                input.properties().scheduling_type,
            )
        };

        // Coalescing partitions invalidates any existing ordering and all
        // per-partition constants:
        let mut eq_properties = input.equivalence_properties().clone();
        eq_properties.clear_orderings();
        eq_properties.clear_per_partition_constants();
        PlanProperties::new(
            eq_properties,
            Partitioning::UnknownPartitioning(1),
            input.pipeline_behavior(),
            input.boundedness(),
        )
        .with_evaluation_type(drive)
        .with_scheduling_type(scheduling)
    }
}

impl DisplayAs for CoalescePartitionsExec {
    fn fmt_as(
        &self,
        t: DisplayFormatType,
        f: &mut std::fmt::Formatter,
    ) -> std::fmt::Result {
        match t {
            DisplayFormatType::Default | DisplayFormatType::Verbose => match self.fetch {
                Some(fetch) => {
                    write!(f, "CoalescePartitionsExec: fetch={fetch}")
                }
                None => write!(f, "CoalescePartitionsExec"),
            },
            DisplayFormatType::TreeRender => match self.fetch {
                Some(fetch) => {
                    write!(f, "limit: {fetch}")
                }
                None => write!(f, ""),
            },
        }
    }
}

impl ExecutionPlan for CoalescePartitionsExec {
    fn name(&self) -> &'static str {
        "CoalescePartitionsExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn properties(&self) -> &PlanProperties {
        &self.cache
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![&self.input]
    }

    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
        vec![false]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        let mut plan = CoalescePartitionsExec::new(Arc::clone(&children[0]));
        plan.fetch = self.fetch;
        Ok(Arc::new(plan))
    }

    fn execute(
        &self,
        partition: usize,
        context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        // This operator produces exactly one output partition
        if partition != 0 {
            return internal_err!("CoalescePartitionsExec invalid partition {partition}");
        }

        let input_partitions = self.input.output_partitioning().partition_count();
        match input_partitions {
            0 => internal_err!(
                "CoalescePartitionsExec requires at least one input partition"
            ),
            1 => {
                // A single input partition needs no merging: delegate directly
                // to the child, applying the fetch limit if one is set
                let child_stream = self.input.execute(0, context)?;
                if self.fetch.is_some() {
                    let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
                    return Ok(Box::pin(ObservedStream::new(
                        child_stream,
                        baseline_metrics,
                        self.fetch,
                    )));
                }
                Ok(child_stream)
            }
            _ => {
                let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
                // Record the (minimal) work done here so that elapsed_compute
                // is not reported as 0
                let elapsed_compute = baseline_metrics.elapsed_compute().clone();
                let _timer = elapsed_compute.timer();

                // Use a stream with enough channel capacity for each sender
                // to buffer at least one batch, to maximize parallelism
                let mut builder =
                    RecordBatchReceiverStream::builder(self.schema(), input_partitions);

                // Spawn one task per input partition; each task's batches are
                // sent to the channel consumed by the output stream
                for part_i in 0..input_partitions {
                    builder.run_input(
                        Arc::clone(&self.input),
                        part_i,
                        Arc::clone(&context),
                    );
                }

                let stream = builder.build();
                Ok(Box::pin(ObservedStream::new(
                    stream,
                    baseline_metrics,
                    self.fetch,
                )))
            }
        }
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }

    fn statistics(&self) -> Result<Statistics> {
        self.partition_statistics(None)
    }

    fn partition_statistics(&self, _partition: Option<usize>) -> Result<Statistics> {
        self.input
            .partition_statistics(None)?
            .with_fetch(self.fetch, 0, 1)
    }

    fn supports_limit_pushdown(&self) -> bool {
        true
    }

    fn cardinality_effect(&self) -> CardinalityEffect {
        CardinalityEffect::Equal
    }

    /// Tries to swap `projection` with its input (this node). The swap is
    /// performed only if the projection narrows its input schema; otherwise
    /// `Ok(None)` is returned and the plans are left unchanged.
    fn try_swapping_with_projection(
        &self,
        projection: &ProjectionExec,
    ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
        // If the projection does not narrow the schema, do not push it down:
        if projection.expr().len() >= projection.input().schema().fields().len() {
            return Ok(None);
        }
        // CoalescePartitionsExec always has exactly one child, so indexing
        // into the projection input's children is safe
        make_with_child(projection, projection.input().children()[0]).map(|e| {
            if self.fetch.is_some() {
                let mut plan = CoalescePartitionsExec::new(e);
                plan.fetch = self.fetch;
                Some(Arc::new(plan) as _)
            } else {
                Some(Arc::new(CoalescePartitionsExec::new(e)) as _)
            }
        })
    }

    fn fetch(&self) -> Option<usize> {
        self.fetch
    }

    fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
        Some(Arc::new(CoalescePartitionsExec {
            input: Arc::clone(&self.input),
            fetch: limit,
            metrics: self.metrics.clone(),
            cache: self.cache.clone(),
        }))
    }

    fn gather_filters_for_pushdown(
        &self,
        _phase: FilterPushdownPhase,
        parent_filters: Vec<Arc<dyn PhysicalExpr>>,
        _config: &ConfigOptions,
    ) -> Result<FilterDescription> {
        FilterDescription::from_children(parent_filters, &self.children())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test::exec::{
        assert_strong_count_converges_to_zero, BlockingExec, PanicExec,
    };
    use crate::test::{self, assert_is_pending};
    use crate::{collect, common};

    use arrow::datatypes::{DataType, Field, Schema};

    use futures::FutureExt;

    #[tokio::test]
    async fn merge() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        // The input should report the expected number of partitions
        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);

        let merge = CoalescePartitionsExec::new(csv);

        // The output of a CoalescePartitionsExec is always a single partition
        assert_eq!(
            merge.properties().output_partitioning().partition_count(),
            1
        );

        // The result should contain one batch per input partition
        let iter = merge.execute(0, task_ctx)?;
        let batches = common::collect(iter).await?;
        assert_eq!(batches.len(), num_partitions);

        // All rows from all partitions should be present (4 partitions * 100 rows)
        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 400);

        Ok(())
    }

    #[tokio::test]
    async fn test_drop_cancel() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 2));
        let refs = blocking_exec.refs();
        let coalesce_partitions_exec =
            Arc::new(CoalescePartitionsExec::new(blocking_exec));

        let fut = collect(coalesce_partitions_exec, task_ctx);
        let mut fut = fut.boxed();

        assert_is_pending(&mut fut);
        drop(fut);
        assert_strong_count_converges_to_zero(refs).await;

        Ok(())
    }

    #[tokio::test]
    #[should_panic(expected = "PanickingStream did panic")]
    async fn test_panic() {
        let task_ctx = Arc::new(TaskContext::default());
        let schema =
            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)]));

        let panicking_exec = Arc::new(PanicExec::new(Arc::clone(&schema), 2));
        let coalesce_partitions_exec =
            Arc::new(CoalescePartitionsExec::new(panicking_exec));

        collect(coalesce_partitions_exec, task_ctx).await.unwrap();
    }

    #[tokio::test]
    async fn test_single_partition_with_fetch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Single-partition input containing 100 rows
        let input = test::scan_partitioned(1);

        // Limit the output to 3 rows
        let coalesce = CoalescePartitionsExec::new(input).with_fetch(Some(3));

        let stream = coalesce.execute(0, task_ctx)?;
        let batches = common::collect(stream).await?;

        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 3, "Should only return 3 rows due to fetch=3");

        Ok(())
    }

    #[tokio::test]
    async fn test_multi_partition_with_fetch_one() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Four input partitions, 100 rows each
        let input = test::scan_partitioned(4);

        // Limit the output to a single row
        let coalesce = CoalescePartitionsExec::new(input).with_fetch(Some(1));

        let stream = coalesce.execute(0, task_ctx)?;
        let batches = common::collect(stream).await?;

        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(
            row_count, 1,
            "Should only return 1 row due to fetch=1, not one per partition"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_single_partition_without_fetch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Single-partition input containing 100 rows
        let input = test::scan_partitioned(1);

        // No fetch limit is applied
        let coalesce = CoalescePartitionsExec::new(input);

        let stream = coalesce.execute(0, task_ctx)?;
        let batches = common::collect(stream).await?;

        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(
            row_count, 100,
            "Should return all 100 rows when fetch is None"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_single_partition_fetch_larger_than_batch() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // Single-partition input containing 100 rows
        let input = test::scan_partitioned(1);

        // Fetch limit larger than the number of available rows
        let coalesce = CoalescePartitionsExec::new(input).with_fetch(Some(200));

        let stream = coalesce.execute(0, task_ctx)?;
        let batches = common::collect(stream).await?;

        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(
            row_count, 100,
            "Should return all available rows (100) when fetch (200) is larger"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_multi_partition_fetch_exact_match() -> Result<()> {
        let task_ctx = Arc::new(TaskContext::default());

        // 4 partitions * 100 rows each = 400 rows total
        let num_partitions = 4;
        let csv = test::scan_partitioned(num_partitions);

        // Fetch limit exactly equal to the total row count
        let coalesce = CoalescePartitionsExec::new(csv).with_fetch(Some(400));

        let stream = coalesce.execute(0, task_ctx)?;
        let batches = common::collect(stream).await?;

        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 400, "Should return exactly 400 rows");

        Ok(())
    }
}