1use std::fmt::Formatter;
23use std::sync::Arc;
24
25use super::bitwise_stream::BitwiseSortMergeJoinStream;
26use super::materializing_stream::MaterializingSortMergeJoinStream;
27use super::metrics::SortMergeJoinMetrics;
28use crate::execution_plan::{EmissionType, boundedness_from_children};
29use crate::expressions::PhysicalSortExpr;
30use crate::joins::utils::{
31 JoinFilter, JoinOn, JoinOnRef, build_join_schema, check_join_is_valid,
32 estimate_join_statistics, reorder_output_after_swap,
33 symmetric_join_output_partitioning,
34};
35use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet, SpillMetrics};
36use crate::projection::{
37 ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
38 physical_to_column_exprs, update_join_on,
39};
40use crate::spill::spill_manager::SpillManager;
41use crate::{
42 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
43 PlanProperties, SendableRecordBatchStream, Statistics, check_if_same_properties,
44};
45
46use arrow::compute::SortOptions;
47use arrow::datatypes::SchemaRef;
48use datafusion_common::{
49 JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err, internal_err,
50 plan_err,
51};
52use datafusion_execution::TaskContext;
53use datafusion_execution::memory_pool::MemoryConsumer;
54use datafusion_physical_expr::equivalence::join_equivalence_properties;
55use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
56use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
57
58#[derive(Debug, Clone)]
107pub struct SortMergeJoinExec {
108 pub left: Arc<dyn ExecutionPlan>,
110 pub right: Arc<dyn ExecutionPlan>,
112 pub on: JoinOn,
114 pub filter: Option<JoinFilter>,
116 pub join_type: JoinType,
118 schema: SchemaRef,
120 metrics: ExecutionPlanMetricsSet,
122 left_sort_exprs: LexOrdering,
124 right_sort_exprs: LexOrdering,
126 pub sort_options: Vec<SortOptions>,
128 pub null_equality: NullEquality,
130 cache: Arc<PlanProperties>,
132}
133
134impl SortMergeJoinExec {
135 pub fn try_new(
140 left: Arc<dyn ExecutionPlan>,
141 right: Arc<dyn ExecutionPlan>,
142 on: JoinOn,
143 filter: Option<JoinFilter>,
144 join_type: JoinType,
145 sort_options: Vec<SortOptions>,
146 null_equality: NullEquality,
147 ) -> Result<Self> {
148 let left_schema = left.schema();
149 let right_schema = right.schema();
150
151 check_join_is_valid(&left_schema, &right_schema, &on)?;
152 if sort_options.len() != on.len() {
153 return plan_err!(
154 "Expected number of sort options: {}, actual: {}",
155 on.len(),
156 sort_options.len()
157 );
158 }
159
160 let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
161 .iter()
162 .zip(sort_options.iter())
163 .map(|((l, r), sort_op)| {
164 let left = PhysicalSortExpr {
165 expr: Arc::clone(l),
166 options: *sort_op,
167 };
168 let right = PhysicalSortExpr {
169 expr: Arc::clone(r),
170 options: *sort_op,
171 };
172 (left, right)
173 })
174 .unzip();
175 let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else {
176 return plan_err!(
177 "SortMergeJoinExec requires valid sort expressions for its left side"
178 );
179 };
180 let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else {
181 return plan_err!(
182 "SortMergeJoinExec requires valid sort expressions for its right side"
183 );
184 };
185
186 let schema =
187 Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
188 let cache =
189 Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?;
190 Ok(Self {
191 left,
192 right,
193 on,
194 filter,
195 join_type,
196 schema,
197 metrics: ExecutionPlanMetricsSet::new(),
198 left_sort_exprs,
199 right_sort_exprs,
200 sort_options,
201 null_equality,
202 cache: Arc::new(cache),
203 })
204 }
205
206 pub fn probe_side(join_type: &JoinType) -> JoinSide {
209 match join_type {
212 JoinType::Right
214 | JoinType::RightSemi
215 | JoinType::RightAnti
216 | JoinType::RightMark => JoinSide::Right,
217 JoinType::Inner
218 | JoinType::Left
219 | JoinType::Full
220 | JoinType::LeftAnti
221 | JoinType::LeftSemi
222 | JoinType::LeftMark => JoinSide::Left,
223 }
224 }
225
226 fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
228 match join_type {
229 JoinType::Inner => vec![true, false],
230 JoinType::Left
231 | JoinType::LeftSemi
232 | JoinType::LeftAnti
233 | JoinType::LeftMark => vec![true, false],
234 JoinType::Right
235 | JoinType::RightSemi
236 | JoinType::RightAnti
237 | JoinType::RightMark => {
238 vec![false, true]
239 }
240 _ => vec![false, false],
241 }
242 }
243
244 pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
246 &self.on
247 }
248
249 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
251 &self.right
252 }
253
254 pub fn join_type(&self) -> JoinType {
256 self.join_type
257 }
258
259 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
261 &self.left
262 }
263
264 pub fn filter(&self) -> &Option<JoinFilter> {
266 &self.filter
267 }
268
269 pub fn sort_options(&self) -> &[SortOptions] {
271 &self.sort_options
272 }
273
274 pub fn null_equality(&self) -> NullEquality {
276 self.null_equality
277 }
278
279 fn compute_properties(
281 left: &Arc<dyn ExecutionPlan>,
282 right: &Arc<dyn ExecutionPlan>,
283 schema: SchemaRef,
284 join_type: JoinType,
285 join_on: JoinOnRef,
286 ) -> Result<PlanProperties> {
287 let eq_properties = join_equivalence_properties(
289 left.equivalence_properties().clone(),
290 right.equivalence_properties().clone(),
291 &join_type,
292 schema,
293 &Self::maintains_input_order(join_type),
294 Some(Self::probe_side(&join_type)),
295 join_on,
296 )?;
297
298 let output_partitioning =
299 symmetric_join_output_partitioning(left, right, &join_type)?;
300
301 Ok(PlanProperties::new(
302 eq_properties,
303 output_partitioning,
304 EmissionType::Incremental,
305 boundedness_from_children([left, right]),
306 ))
307 }
308
309 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
315 let left = self.left();
316 let right = self.right();
317 let new_join = SortMergeJoinExec::try_new(
318 Arc::clone(right),
319 Arc::clone(left),
320 self.on()
321 .iter()
322 .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
323 .collect::<Vec<_>>(),
324 self.filter().as_ref().map(JoinFilter::swap),
325 self.join_type().swap(),
326 self.sort_options.clone(),
327 self.null_equality,
328 )?;
329
330 if matches!(
333 self.join_type(),
334 JoinType::LeftSemi
335 | JoinType::RightSemi
336 | JoinType::LeftAnti
337 | JoinType::RightAnti
338 | JoinType::LeftMark
339 | JoinType::RightMark
340 ) {
341 Ok(Arc::new(new_join))
342 } else {
343 reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
344 }
345 }
346
347 fn with_new_children_and_same_properties(
348 &self,
349 mut children: Vec<Arc<dyn ExecutionPlan>>,
350 ) -> Self {
351 let left = children.swap_remove(0);
352 let right = children.swap_remove(0);
353 Self {
354 left,
355 right,
356 metrics: ExecutionPlanMetricsSet::new(),
357 ..Self::clone(self)
358 }
359 }
360}
361
362impl DisplayAs for SortMergeJoinExec {
363 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
364 match t {
365 DisplayFormatType::Default | DisplayFormatType::Verbose => {
366 let on = self
367 .on
368 .iter()
369 .map(|(c1, c2)| format!("({c1}, {c2})"))
370 .collect::<Vec<String>>()
371 .join(", ");
372 let display_null_equality =
373 if self.null_equality() == NullEquality::NullEqualsNull {
374 ", NullsEqual: true"
375 } else {
376 ""
377 };
378 write!(
379 f,
380 "{}: join_type={:?}, on=[{}]{}{}",
381 Self::static_name(),
382 self.join_type,
383 on,
384 self.filter.as_ref().map_or_else(
385 || "".to_string(),
386 |f| format!(", filter={}", f.expression())
387 ),
388 display_null_equality,
389 )
390 }
391 DisplayFormatType::TreeRender => {
392 let on = self
393 .on
394 .iter()
395 .map(|(c1, c2)| {
396 format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
397 })
398 .collect::<Vec<String>>()
399 .join(", ");
400
401 if self.join_type() != JoinType::Inner {
402 writeln!(f, "join_type={:?}", self.join_type)?;
403 }
404 writeln!(f, "on={on}")?;
405
406 if self.null_equality() == NullEquality::NullEqualsNull {
407 writeln!(f, "NullsEqual: true")?;
408 }
409
410 Ok(())
411 }
412 }
413 }
414}
415
416impl ExecutionPlan for SortMergeJoinExec {
417 fn name(&self) -> &'static str {
418 "SortMergeJoinExec"
419 }
420
421 fn properties(&self) -> &Arc<PlanProperties> {
422 &self.cache
423 }
424
425 fn required_input_distribution(&self) -> Vec<Distribution> {
426 let (left_expr, right_expr) = self
427 .on
428 .iter()
429 .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
430 .unzip();
431 vec![
432 Distribution::HashPartitioned(left_expr),
433 Distribution::HashPartitioned(right_expr),
434 ]
435 }
436
437 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
438 vec![
439 Some(OrderingRequirements::from(self.left_sort_exprs.clone())),
440 Some(OrderingRequirements::from(self.right_sort_exprs.clone())),
441 ]
442 }
443
444 fn maintains_input_order(&self) -> Vec<bool> {
445 Self::maintains_input_order(self.join_type)
446 }
447
448 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
449 vec![&self.left, &self.right]
450 }
451
452 fn with_new_children(
453 self: Arc<Self>,
454 children: Vec<Arc<dyn ExecutionPlan>>,
455 ) -> Result<Arc<dyn ExecutionPlan>> {
456 check_if_same_properties!(self, children);
457 match &children[..] {
458 [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
459 Arc::clone(left),
460 Arc::clone(right),
461 self.on.clone(),
462 self.filter.clone(),
463 self.join_type,
464 self.sort_options.clone(),
465 self.null_equality,
466 )?)),
467 _ => internal_err!("SortMergeJoin wrong number of children"),
468 }
469 }
470
471 fn execute(
472 &self,
473 partition: usize,
474 context: Arc<TaskContext>,
475 ) -> Result<SendableRecordBatchStream> {
476 let left_partitions = self.left.output_partitioning().partition_count();
477 let right_partitions = self.right.output_partitioning().partition_count();
478 assert_eq_or_internal_err!(
479 left_partitions,
480 right_partitions,
481 "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
482 consider using RepartitionExec"
483 );
484 let (on_left, on_right) = self.on.iter().cloned().unzip();
485 let (streamed, buffered, on_streamed, on_buffered) =
486 if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
487 (
488 Arc::clone(&self.left),
489 Arc::clone(&self.right),
490 on_left,
491 on_right,
492 )
493 } else {
494 (
495 Arc::clone(&self.right),
496 Arc::clone(&self.left),
497 on_right,
498 on_left,
499 )
500 };
501
502 let streamed = streamed.execute(partition, Arc::clone(&context))?;
504 let buffered = buffered.execute(partition, Arc::clone(&context))?;
505
506 let batch_size = context.session_config().batch_size();
507 let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
508 .register(context.memory_pool());
509 let spill_manager = SpillManager::new(
510 context.runtime_env(),
511 SpillMetrics::new(&self.metrics, partition),
512 buffered.schema(),
513 )
514 .with_compression_type(context.session_config().spill_compression());
515
516 if matches!(
517 self.join_type,
518 JoinType::LeftSemi
519 | JoinType::LeftAnti
520 | JoinType::RightSemi
521 | JoinType::RightAnti
522 | JoinType::LeftMark
523 | JoinType::RightMark
524 ) {
525 Ok(Box::pin(BitwiseSortMergeJoinStream::try_new(
526 Arc::clone(&self.schema),
527 self.sort_options.clone(),
528 self.null_equality,
529 streamed,
530 buffered,
531 on_streamed,
532 on_buffered,
533 self.filter.clone(),
534 self.join_type,
535 batch_size,
536 partition,
537 &self.metrics,
538 reservation,
539 spill_manager,
540 context.runtime_env(),
541 )?))
542 } else {
543 Ok(Box::pin(MaterializingSortMergeJoinStream::try_new(
544 Arc::clone(&self.schema),
545 self.sort_options.clone(),
546 self.null_equality,
547 streamed,
548 buffered,
549 on_streamed,
550 on_buffered,
551 self.filter.clone(),
552 self.join_type,
553 batch_size,
554 SortMergeJoinMetrics::new(partition, &self.metrics),
555 reservation,
556 spill_manager,
557 context.runtime_env(),
558 )?))
559 }
560 }
561
562 fn metrics(&self) -> Option<MetricsSet> {
563 Some(self.metrics.clone_inner())
564 }
565
566 fn partition_statistics(&self, partition: Option<usize>) -> Result<Arc<Statistics>> {
567 let left_stats = Arc::unwrap_or_clone(self.left.partition_statistics(partition)?);
579 let right_stats =
580 Arc::unwrap_or_clone(self.right.partition_statistics(partition)?);
581 Ok(Arc::new(estimate_join_statistics(
582 left_stats,
583 right_stats,
584 &self.on,
585 &self.join_type,
586 &self.schema,
587 )?))
588 }
589
590 fn try_swapping_with_projection(
594 &self,
595 projection: &ProjectionExec,
596 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
597 let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
599 else {
600 return Ok(None);
601 };
602
603 let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
604 self.left().schema().fields().len(),
605 &projection_as_columns,
606 );
607
608 if !join_allows_pushdown(
609 &projection_as_columns,
610 &self.schema(),
611 far_right_left_col_ind,
612 far_left_right_col_ind,
613 ) {
614 return Ok(None);
615 }
616
617 let Some(new_on) = update_join_on(
618 &projection_as_columns[0..=far_right_left_col_ind as _],
619 &projection_as_columns[far_left_right_col_ind as _..],
620 self.on(),
621 self.left().schema().fields().len(),
622 ) else {
623 return Ok(None);
624 };
625
626 let (new_left, new_right) = new_join_children(
627 &projection_as_columns,
628 far_right_left_col_ind,
629 far_left_right_col_ind,
630 self.children()[0],
631 self.children()[1],
632 )?;
633
634 Ok(Some(Arc::new(SortMergeJoinExec::try_new(
635 Arc::new(new_left),
636 Arc::new(new_right),
637 new_on,
638 self.filter.clone(),
639 self.join_type,
640 self.sort_options.clone(),
641 self.null_equality,
642 )?)))
643 }
644}