1use std::any::Any;
23use std::fmt::Formatter;
24use std::sync::Arc;
25
26use crate::execution_plan::{EmissionType, boundedness_from_children};
27use crate::expressions::PhysicalSortExpr;
28use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
29use crate::joins::sort_merge_join::stream::SortMergeJoinStream;
30use crate::joins::utils::{
31 JoinFilter, JoinOn, JoinOnRef, build_join_schema, check_join_is_valid,
32 estimate_join_statistics, reorder_output_after_swap,
33 symmetric_join_output_partitioning,
34};
35use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
36use crate::projection::{
37 ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
38 physical_to_column_exprs, update_join_on,
39};
40use crate::{
41 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
42 PlanProperties, SendableRecordBatchStream, Statistics, check_if_same_properties,
43};
44
45use arrow::compute::SortOptions;
46use arrow::datatypes::SchemaRef;
47use datafusion_common::{
48 JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err, internal_err,
49 plan_err,
50};
51use datafusion_execution::TaskContext;
52use datafusion_execution::memory_pool::MemoryConsumer;
53use datafusion_physical_expr::equivalence::join_equivalence_properties;
54use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
55use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
56
57#[derive(Debug, Clone)]
106pub struct SortMergeJoinExec {
107 pub left: Arc<dyn ExecutionPlan>,
109 pub right: Arc<dyn ExecutionPlan>,
111 pub on: JoinOn,
113 pub filter: Option<JoinFilter>,
115 pub join_type: JoinType,
117 schema: SchemaRef,
119 metrics: ExecutionPlanMetricsSet,
121 left_sort_exprs: LexOrdering,
123 right_sort_exprs: LexOrdering,
125 pub sort_options: Vec<SortOptions>,
127 pub null_equality: NullEquality,
129 cache: Arc<PlanProperties>,
131}
132
133impl SortMergeJoinExec {
134 pub fn try_new(
139 left: Arc<dyn ExecutionPlan>,
140 right: Arc<dyn ExecutionPlan>,
141 on: JoinOn,
142 filter: Option<JoinFilter>,
143 join_type: JoinType,
144 sort_options: Vec<SortOptions>,
145 null_equality: NullEquality,
146 ) -> Result<Self> {
147 let left_schema = left.schema();
148 let right_schema = right.schema();
149
150 check_join_is_valid(&left_schema, &right_schema, &on)?;
151 if sort_options.len() != on.len() {
152 return plan_err!(
153 "Expected number of sort options: {}, actual: {}",
154 on.len(),
155 sort_options.len()
156 );
157 }
158
159 let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
160 .iter()
161 .zip(sort_options.iter())
162 .map(|((l, r), sort_op)| {
163 let left = PhysicalSortExpr {
164 expr: Arc::clone(l),
165 options: *sort_op,
166 };
167 let right = PhysicalSortExpr {
168 expr: Arc::clone(r),
169 options: *sort_op,
170 };
171 (left, right)
172 })
173 .unzip();
174 let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else {
175 return plan_err!(
176 "SortMergeJoinExec requires valid sort expressions for its left side"
177 );
178 };
179 let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else {
180 return plan_err!(
181 "SortMergeJoinExec requires valid sort expressions for its right side"
182 );
183 };
184
185 let schema =
186 Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
187 let cache =
188 Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?;
189 Ok(Self {
190 left,
191 right,
192 on,
193 filter,
194 join_type,
195 schema,
196 metrics: ExecutionPlanMetricsSet::new(),
197 left_sort_exprs,
198 right_sort_exprs,
199 sort_options,
200 null_equality,
201 cache: Arc::new(cache),
202 })
203 }
204
205 pub fn probe_side(join_type: &JoinType) -> JoinSide {
208 match join_type {
211 JoinType::Right
213 | JoinType::RightSemi
214 | JoinType::RightAnti
215 | JoinType::RightMark => JoinSide::Right,
216 JoinType::Inner
217 | JoinType::Left
218 | JoinType::Full
219 | JoinType::LeftAnti
220 | JoinType::LeftSemi
221 | JoinType::LeftMark => JoinSide::Left,
222 }
223 }
224
225 fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
227 match join_type {
228 JoinType::Inner => vec![true, false],
229 JoinType::Left
230 | JoinType::LeftSemi
231 | JoinType::LeftAnti
232 | JoinType::LeftMark => vec![true, false],
233 JoinType::Right
234 | JoinType::RightSemi
235 | JoinType::RightAnti
236 | JoinType::RightMark => {
237 vec![false, true]
238 }
239 _ => vec![false, false],
240 }
241 }
242
243 pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
245 &self.on
246 }
247
248 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
250 &self.right
251 }
252
253 pub fn join_type(&self) -> JoinType {
255 self.join_type
256 }
257
258 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
260 &self.left
261 }
262
263 pub fn filter(&self) -> &Option<JoinFilter> {
265 &self.filter
266 }
267
268 pub fn sort_options(&self) -> &[SortOptions] {
270 &self.sort_options
271 }
272
273 pub fn null_equality(&self) -> NullEquality {
275 self.null_equality
276 }
277
278 fn compute_properties(
280 left: &Arc<dyn ExecutionPlan>,
281 right: &Arc<dyn ExecutionPlan>,
282 schema: SchemaRef,
283 join_type: JoinType,
284 join_on: JoinOnRef,
285 ) -> Result<PlanProperties> {
286 let eq_properties = join_equivalence_properties(
288 left.equivalence_properties().clone(),
289 right.equivalence_properties().clone(),
290 &join_type,
291 schema,
292 &Self::maintains_input_order(join_type),
293 Some(Self::probe_side(&join_type)),
294 join_on,
295 )?;
296
297 let output_partitioning =
298 symmetric_join_output_partitioning(left, right, &join_type)?;
299
300 Ok(PlanProperties::new(
301 eq_properties,
302 output_partitioning,
303 EmissionType::Incremental,
304 boundedness_from_children([left, right]),
305 ))
306 }
307
308 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
314 let left = self.left();
315 let right = self.right();
316 let new_join = SortMergeJoinExec::try_new(
317 Arc::clone(right),
318 Arc::clone(left),
319 self.on()
320 .iter()
321 .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
322 .collect::<Vec<_>>(),
323 self.filter().as_ref().map(JoinFilter::swap),
324 self.join_type().swap(),
325 self.sort_options.clone(),
326 self.null_equality,
327 )?;
328
329 if matches!(
332 self.join_type(),
333 JoinType::LeftSemi
334 | JoinType::RightSemi
335 | JoinType::LeftAnti
336 | JoinType::RightAnti
337 ) {
338 Ok(Arc::new(new_join))
339 } else {
340 reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
341 }
342 }
343
344 fn with_new_children_and_same_properties(
345 &self,
346 mut children: Vec<Arc<dyn ExecutionPlan>>,
347 ) -> Self {
348 let left = children.swap_remove(0);
349 let right = children.swap_remove(0);
350 Self {
351 left,
352 right,
353 metrics: ExecutionPlanMetricsSet::new(),
354 ..Self::clone(self)
355 }
356 }
357}
358
359impl DisplayAs for SortMergeJoinExec {
360 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
361 match t {
362 DisplayFormatType::Default | DisplayFormatType::Verbose => {
363 let on = self
364 .on
365 .iter()
366 .map(|(c1, c2)| format!("({c1}, {c2})"))
367 .collect::<Vec<String>>()
368 .join(", ");
369 let display_null_equality =
370 if self.null_equality() == NullEquality::NullEqualsNull {
371 ", NullsEqual: true"
372 } else {
373 ""
374 };
375 write!(
376 f,
377 "{}: join_type={:?}, on=[{}]{}{}",
378 Self::static_name(),
379 self.join_type,
380 on,
381 self.filter.as_ref().map_or_else(
382 || "".to_string(),
383 |f| format!(", filter={}", f.expression())
384 ),
385 display_null_equality,
386 )
387 }
388 DisplayFormatType::TreeRender => {
389 let on = self
390 .on
391 .iter()
392 .map(|(c1, c2)| {
393 format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
394 })
395 .collect::<Vec<String>>()
396 .join(", ");
397
398 if self.join_type() != JoinType::Inner {
399 writeln!(f, "join_type={:?}", self.join_type)?;
400 }
401 writeln!(f, "on={on}")?;
402
403 if self.null_equality() == NullEquality::NullEqualsNull {
404 writeln!(f, "NullsEqual: true")?;
405 }
406
407 Ok(())
408 }
409 }
410 }
411}
412
413impl ExecutionPlan for SortMergeJoinExec {
414 fn name(&self) -> &'static str {
415 "SortMergeJoinExec"
416 }
417
418 fn as_any(&self) -> &dyn Any {
419 self
420 }
421
422 fn properties(&self) -> &Arc<PlanProperties> {
423 &self.cache
424 }
425
426 fn required_input_distribution(&self) -> Vec<Distribution> {
427 let (left_expr, right_expr) = self
428 .on
429 .iter()
430 .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
431 .unzip();
432 vec![
433 Distribution::HashPartitioned(left_expr),
434 Distribution::HashPartitioned(right_expr),
435 ]
436 }
437
438 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
439 vec![
440 Some(OrderingRequirements::from(self.left_sort_exprs.clone())),
441 Some(OrderingRequirements::from(self.right_sort_exprs.clone())),
442 ]
443 }
444
445 fn maintains_input_order(&self) -> Vec<bool> {
446 Self::maintains_input_order(self.join_type)
447 }
448
449 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
450 vec![&self.left, &self.right]
451 }
452
453 fn with_new_children(
454 self: Arc<Self>,
455 children: Vec<Arc<dyn ExecutionPlan>>,
456 ) -> Result<Arc<dyn ExecutionPlan>> {
457 check_if_same_properties!(self, children);
458 match &children[..] {
459 [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
460 Arc::clone(left),
461 Arc::clone(right),
462 self.on.clone(),
463 self.filter.clone(),
464 self.join_type,
465 self.sort_options.clone(),
466 self.null_equality,
467 )?)),
468 _ => internal_err!("SortMergeJoin wrong number of children"),
469 }
470 }
471
472 fn execute(
473 &self,
474 partition: usize,
475 context: Arc<TaskContext>,
476 ) -> Result<SendableRecordBatchStream> {
477 let left_partitions = self.left.output_partitioning().partition_count();
478 let right_partitions = self.right.output_partitioning().partition_count();
479 assert_eq_or_internal_err!(
480 left_partitions,
481 right_partitions,
482 "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
483 consider using RepartitionExec"
484 );
485 let (on_left, on_right) = self.on.iter().cloned().unzip();
486 let (streamed, buffered, on_streamed, on_buffered) =
487 if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
488 (
489 Arc::clone(&self.left),
490 Arc::clone(&self.right),
491 on_left,
492 on_right,
493 )
494 } else {
495 (
496 Arc::clone(&self.right),
497 Arc::clone(&self.left),
498 on_right,
499 on_left,
500 )
501 };
502
503 let streamed = streamed.execute(partition, Arc::clone(&context))?;
505 let buffered = buffered.execute(partition, Arc::clone(&context))?;
506
507 let batch_size = context.session_config().batch_size();
509
510 let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
512 .register(context.memory_pool());
513
514 Ok(Box::pin(SortMergeJoinStream::try_new(
516 context.session_config().spill_compression(),
517 Arc::clone(&self.schema),
518 self.sort_options.clone(),
519 self.null_equality,
520 streamed,
521 buffered,
522 on_streamed,
523 on_buffered,
524 self.filter.clone(),
525 self.join_type,
526 batch_size,
527 SortMergeJoinMetrics::new(partition, &self.metrics),
528 reservation,
529 context.runtime_env(),
530 )?))
531 }
532
533 fn metrics(&self) -> Option<MetricsSet> {
534 Some(self.metrics.clone_inner())
535 }
536
537 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
538 estimate_join_statistics(
550 self.left.partition_statistics(partition)?,
551 self.right.partition_statistics(partition)?,
552 &self.on,
553 &self.join_type,
554 &self.schema,
555 )
556 }
557
558 fn try_swapping_with_projection(
562 &self,
563 projection: &ProjectionExec,
564 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
565 let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
567 else {
568 return Ok(None);
569 };
570
571 let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
572 self.left().schema().fields().len(),
573 &projection_as_columns,
574 );
575
576 if !join_allows_pushdown(
577 &projection_as_columns,
578 &self.schema(),
579 far_right_left_col_ind,
580 far_left_right_col_ind,
581 ) {
582 return Ok(None);
583 }
584
585 let Some(new_on) = update_join_on(
586 &projection_as_columns[0..=far_right_left_col_ind as _],
587 &projection_as_columns[far_left_right_col_ind as _..],
588 self.on(),
589 self.left().schema().fields().len(),
590 ) else {
591 return Ok(None);
592 };
593
594 let (new_left, new_right) = new_join_children(
595 &projection_as_columns,
596 far_right_left_col_ind,
597 far_left_right_col_ind,
598 self.children()[0],
599 self.children()[1],
600 )?;
601
602 Ok(Some(Arc::new(SortMergeJoinExec::try_new(
603 Arc::new(new_left),
604 Arc::new(new_right),
605 new_on,
606 self.filter.clone(),
607 self.join_type,
608 self.sort_options.clone(),
609 self.null_equality,
610 )?)))
611 }
612}