1use std::any::Any;
23use std::fmt::Formatter;
24use std::sync::Arc;
25
26use crate::execution_plan::{boundedness_from_children, EmissionType};
27use crate::expressions::PhysicalSortExpr;
28use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
29use crate::joins::sort_merge_join::stream::SortMergeJoinStream;
30use crate::joins::utils::{
31 build_join_schema, check_join_is_valid, estimate_join_statistics,
32 reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn,
33 JoinOnRef,
34};
35use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
36use crate::projection::{
37 join_allows_pushdown, join_table_borders, new_join_children,
38 physical_to_column_exprs, update_join_on, ProjectionExec,
39};
40use crate::{
41 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
42 PlanProperties, SendableRecordBatchStream, Statistics,
43};
44
45use arrow::compute::SortOptions;
46use arrow::datatypes::SchemaRef;
47use datafusion_common::{
48 internal_err, plan_err, JoinSide, JoinType, NullEquality, Result,
49};
50use datafusion_execution::memory_pool::MemoryConsumer;
51use datafusion_execution::TaskContext;
52use datafusion_physical_expr::equivalence::join_equivalence_properties;
53use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef};
54use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
55
56#[derive(Debug, Clone)]
105pub struct SortMergeJoinExec {
106 pub left: Arc<dyn ExecutionPlan>,
108 pub right: Arc<dyn ExecutionPlan>,
110 pub on: JoinOn,
112 pub filter: Option<JoinFilter>,
114 pub join_type: JoinType,
116 schema: SchemaRef,
118 metrics: ExecutionPlanMetricsSet,
120 left_sort_exprs: LexOrdering,
122 right_sort_exprs: LexOrdering,
124 pub sort_options: Vec<SortOptions>,
126 pub null_equality: NullEquality,
128 cache: PlanProperties,
130}
131
132impl SortMergeJoinExec {
133 pub fn try_new(
138 left: Arc<dyn ExecutionPlan>,
139 right: Arc<dyn ExecutionPlan>,
140 on: JoinOn,
141 filter: Option<JoinFilter>,
142 join_type: JoinType,
143 sort_options: Vec<SortOptions>,
144 null_equality: NullEquality,
145 ) -> Result<Self> {
146 let left_schema = left.schema();
147 let right_schema = right.schema();
148
149 check_join_is_valid(&left_schema, &right_schema, &on)?;
150 if sort_options.len() != on.len() {
151 return plan_err!(
152 "Expected number of sort options: {}, actual: {}",
153 on.len(),
154 sort_options.len()
155 );
156 }
157
158 let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
159 .iter()
160 .zip(sort_options.iter())
161 .map(|((l, r), sort_op)| {
162 let left = PhysicalSortExpr {
163 expr: Arc::clone(l),
164 options: *sort_op,
165 };
166 let right = PhysicalSortExpr {
167 expr: Arc::clone(r),
168 options: *sort_op,
169 };
170 (left, right)
171 })
172 .unzip();
173 let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else {
174 return plan_err!(
175 "SortMergeJoinExec requires valid sort expressions for its left side"
176 );
177 };
178 let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else {
179 return plan_err!(
180 "SortMergeJoinExec requires valid sort expressions for its right side"
181 );
182 };
183
184 let schema =
185 Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
186 let cache =
187 Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?;
188 Ok(Self {
189 left,
190 right,
191 on,
192 filter,
193 join_type,
194 schema,
195 metrics: ExecutionPlanMetricsSet::new(),
196 left_sort_exprs,
197 right_sort_exprs,
198 sort_options,
199 null_equality,
200 cache,
201 })
202 }
203
204 pub fn probe_side(join_type: &JoinType) -> JoinSide {
207 match join_type {
210 JoinType::Right
212 | JoinType::RightSemi
213 | JoinType::RightAnti
214 | JoinType::RightMark => JoinSide::Right,
215 JoinType::Inner
216 | JoinType::Left
217 | JoinType::Full
218 | JoinType::LeftAnti
219 | JoinType::LeftSemi
220 | JoinType::LeftMark => JoinSide::Left,
221 }
222 }
223
224 fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
226 match join_type {
227 JoinType::Inner => vec![true, false],
228 JoinType::Left
229 | JoinType::LeftSemi
230 | JoinType::LeftAnti
231 | JoinType::LeftMark => vec![true, false],
232 JoinType::Right
233 | JoinType::RightSemi
234 | JoinType::RightAnti
235 | JoinType::RightMark => {
236 vec![false, true]
237 }
238 _ => vec![false, false],
239 }
240 }
241
242 pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
244 &self.on
245 }
246
247 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
249 &self.right
250 }
251
252 pub fn join_type(&self) -> JoinType {
254 self.join_type
255 }
256
257 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
259 &self.left
260 }
261
262 pub fn filter(&self) -> &Option<JoinFilter> {
264 &self.filter
265 }
266
267 pub fn sort_options(&self) -> &[SortOptions] {
269 &self.sort_options
270 }
271
272 pub fn null_equality(&self) -> NullEquality {
274 self.null_equality
275 }
276
277 fn compute_properties(
279 left: &Arc<dyn ExecutionPlan>,
280 right: &Arc<dyn ExecutionPlan>,
281 schema: SchemaRef,
282 join_type: JoinType,
283 join_on: JoinOnRef,
284 ) -> Result<PlanProperties> {
285 let eq_properties = join_equivalence_properties(
287 left.equivalence_properties().clone(),
288 right.equivalence_properties().clone(),
289 &join_type,
290 schema,
291 &Self::maintains_input_order(join_type),
292 Some(Self::probe_side(&join_type)),
293 join_on,
294 )?;
295
296 let output_partitioning =
297 symmetric_join_output_partitioning(left, right, &join_type)?;
298
299 Ok(PlanProperties::new(
300 eq_properties,
301 output_partitioning,
302 EmissionType::Incremental,
303 boundedness_from_children([left, right]),
304 ))
305 }
306
307 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
313 let left = self.left();
314 let right = self.right();
315 let new_join = SortMergeJoinExec::try_new(
316 Arc::clone(right),
317 Arc::clone(left),
318 self.on()
319 .iter()
320 .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
321 .collect::<Vec<_>>(),
322 self.filter().as_ref().map(JoinFilter::swap),
323 self.join_type().swap(),
324 self.sort_options.clone(),
325 self.null_equality,
326 )?;
327
328 if matches!(
331 self.join_type(),
332 JoinType::LeftSemi
333 | JoinType::RightSemi
334 | JoinType::LeftAnti
335 | JoinType::RightAnti
336 ) {
337 Ok(Arc::new(new_join))
338 } else {
339 reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
340 }
341 }
342}
343
344impl DisplayAs for SortMergeJoinExec {
345 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
346 match t {
347 DisplayFormatType::Default | DisplayFormatType::Verbose => {
348 let on = self
349 .on
350 .iter()
351 .map(|(c1, c2)| format!("({c1}, {c2})"))
352 .collect::<Vec<String>>()
353 .join(", ");
354 let display_null_equality =
355 if matches!(self.null_equality(), NullEquality::NullEqualsNull) {
356 ", NullsEqual: true"
357 } else {
358 ""
359 };
360 write!(
361 f,
362 "SortMergeJoin: join_type={:?}, on=[{}]{}{}",
363 self.join_type,
364 on,
365 self.filter.as_ref().map_or_else(
366 || "".to_string(),
367 |f| format!(", filter={}", f.expression())
368 ),
369 display_null_equality,
370 )
371 }
372 DisplayFormatType::TreeRender => {
373 let on = self
374 .on
375 .iter()
376 .map(|(c1, c2)| {
377 format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
378 })
379 .collect::<Vec<String>>()
380 .join(", ");
381
382 if self.join_type() != JoinType::Inner {
383 writeln!(f, "join_type={:?}", self.join_type)?;
384 }
385 writeln!(f, "on={on}")?;
386
387 if matches!(self.null_equality(), NullEquality::NullEqualsNull) {
388 writeln!(f, "NullsEqual: true")?;
389 }
390
391 Ok(())
392 }
393 }
394 }
395}
396
397impl ExecutionPlan for SortMergeJoinExec {
398 fn name(&self) -> &'static str {
399 "SortMergeJoinExec"
400 }
401
402 fn as_any(&self) -> &dyn Any {
403 self
404 }
405
406 fn properties(&self) -> &PlanProperties {
407 &self.cache
408 }
409
410 fn required_input_distribution(&self) -> Vec<Distribution> {
411 let (left_expr, right_expr) = self
412 .on
413 .iter()
414 .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
415 .unzip();
416 vec![
417 Distribution::HashPartitioned(left_expr),
418 Distribution::HashPartitioned(right_expr),
419 ]
420 }
421
422 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
423 vec![
424 Some(OrderingRequirements::from(self.left_sort_exprs.clone())),
425 Some(OrderingRequirements::from(self.right_sort_exprs.clone())),
426 ]
427 }
428
429 fn maintains_input_order(&self) -> Vec<bool> {
430 Self::maintains_input_order(self.join_type)
431 }
432
433 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
434 vec![&self.left, &self.right]
435 }
436
437 fn with_new_children(
438 self: Arc<Self>,
439 children: Vec<Arc<dyn ExecutionPlan>>,
440 ) -> Result<Arc<dyn ExecutionPlan>> {
441 match &children[..] {
442 [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
443 Arc::clone(left),
444 Arc::clone(right),
445 self.on.clone(),
446 self.filter.clone(),
447 self.join_type,
448 self.sort_options.clone(),
449 self.null_equality,
450 )?)),
451 _ => internal_err!("SortMergeJoin wrong number of children"),
452 }
453 }
454
455 fn execute(
456 &self,
457 partition: usize,
458 context: Arc<TaskContext>,
459 ) -> Result<SendableRecordBatchStream> {
460 let left_partitions = self.left.output_partitioning().partition_count();
461 let right_partitions = self.right.output_partitioning().partition_count();
462 if left_partitions != right_partitions {
463 return internal_err!(
464 "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
465 consider using RepartitionExec"
466 );
467 }
468 let (on_left, on_right) = self.on.iter().cloned().unzip();
469 let (streamed, buffered, on_streamed, on_buffered) =
470 if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
471 (
472 Arc::clone(&self.left),
473 Arc::clone(&self.right),
474 on_left,
475 on_right,
476 )
477 } else {
478 (
479 Arc::clone(&self.right),
480 Arc::clone(&self.left),
481 on_right,
482 on_left,
483 )
484 };
485
486 let streamed = streamed.execute(partition, Arc::clone(&context))?;
488 let buffered = buffered.execute(partition, Arc::clone(&context))?;
489
490 let batch_size = context.session_config().batch_size();
492
493 let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
495 .register(context.memory_pool());
496
497 Ok(Box::pin(SortMergeJoinStream::try_new(
499 context.session_config().spill_compression(),
500 Arc::clone(&self.schema),
501 self.sort_options.clone(),
502 self.null_equality,
503 streamed,
504 buffered,
505 on_streamed,
506 on_buffered,
507 self.filter.clone(),
508 self.join_type,
509 batch_size,
510 SortMergeJoinMetrics::new(partition, &self.metrics),
511 reservation,
512 context.runtime_env(),
513 )?))
514 }
515
516 fn metrics(&self) -> Option<MetricsSet> {
517 Some(self.metrics.clone_inner())
518 }
519
520 fn statistics(&self) -> Result<Statistics> {
521 self.partition_statistics(None)
522 }
523
524 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
525 if partition.is_some() {
526 return Ok(Statistics::new_unknown(&self.schema()));
527 }
528 estimate_join_statistics(
532 self.left.partition_statistics(None)?,
533 self.right.partition_statistics(None)?,
534 self.on.clone(),
535 &self.join_type,
536 &self.schema,
537 )
538 }
539
540 fn try_swapping_with_projection(
544 &self,
545 projection: &ProjectionExec,
546 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
547 let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
549 else {
550 return Ok(None);
551 };
552
553 let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
554 self.left().schema().fields().len(),
555 &projection_as_columns,
556 );
557
558 if !join_allows_pushdown(
559 &projection_as_columns,
560 &self.schema(),
561 far_right_left_col_ind,
562 far_left_right_col_ind,
563 ) {
564 return Ok(None);
565 }
566
567 let Some(new_on) = update_join_on(
568 &projection_as_columns[0..=far_right_left_col_ind as _],
569 &projection_as_columns[far_left_right_col_ind as _..],
570 self.on(),
571 self.left().schema().fields().len(),
572 ) else {
573 return Ok(None);
574 };
575
576 let (new_left, new_right) = new_join_children(
577 &projection_as_columns,
578 far_right_left_col_ind,
579 far_left_right_col_ind,
580 self.children()[0],
581 self.children()[1],
582 )?;
583
584 Ok(Some(Arc::new(SortMergeJoinExec::try_new(
585 Arc::new(new_left),
586 Arc::new(new_right),
587 new_on,
588 self.filter.clone(),
589 self.join_type,
590 self.sort_options.clone(),
591 self.null_equality,
592 )?)))
593 }
594}