1use std::any::Any;
23use std::fmt::Formatter;
24use std::sync::Arc;
25
26use crate::execution_plan::{EmissionType, boundedness_from_children};
27use crate::expressions::PhysicalSortExpr;
28use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
29use crate::joins::sort_merge_join::stream::SortMergeJoinStream;
30use crate::joins::utils::{
31 JoinFilter, JoinOn, JoinOnRef, build_join_schema, check_join_is_valid,
32 estimate_join_statistics, reorder_output_after_swap,
33 symmetric_join_output_partitioning,
34};
35use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::projection::{
    ProjectionExec, join_allows_pushdown, join_table_borders, new_join_children,
    physical_to_column_exprs, update_join_filter, update_join_on,
};
40use crate::{
41 DisplayAs, DisplayFormatType, Distribution, ExecutionPlan, ExecutionPlanProperties,
42 PlanProperties, SendableRecordBatchStream, Statistics,
43};
44
45use arrow::compute::SortOptions;
46use arrow::datatypes::SchemaRef;
47use datafusion_common::{
48 JoinSide, JoinType, NullEquality, Result, assert_eq_or_internal_err, internal_err,
49 plan_err,
50};
51use datafusion_execution::TaskContext;
52use datafusion_execution::memory_pool::MemoryConsumer;
53use datafusion_physical_expr::equivalence::join_equivalence_properties;
54use datafusion_physical_expr_common::physical_expr::{PhysicalExprRef, fmt_sql};
55use datafusion_physical_expr_common::sort_expr::{LexOrdering, OrderingRequirements};
56
/// Join execution plan that merges two inputs on equi-join keys.
///
/// Each input is required to be hash partitioned on its join key expressions
/// (see `required_input_distribution`) and ordered by those expressions using
/// `sort_options` (see `required_input_ordering`). At execution time one input
/// is streamed and the other buffered (see `probe_side`), and the actual work is
/// done by a `SortMergeJoinStream`.
#[derive(Debug, Clone)]
pub struct SortMergeJoinExec {
    /// Left input plan.
    pub left: Arc<dyn ExecutionPlan>,
    /// Right input plan.
    pub right: Arc<dyn ExecutionPlan>,
    /// Equi-join key pairs, each `(left expression, right expression)`.
    pub on: JoinOn,
    /// Optional join filter, forwarded unchanged to the join stream in `execute`.
    pub filter: Option<JoinFilter>,
    /// Kind of join to perform.
    pub join_type: JoinType,
    /// Output schema, produced by `build_join_schema` from both input schemas
    /// and the join type.
    schema: SchemaRef,
    /// Execution metrics; each partition's `SortMergeJoinMetrics` is created
    /// against this set.
    metrics: ExecutionPlanMetricsSet,
    /// Ordering required of the left input: the left join keys, each paired
    /// with the corresponding entry of `sort_options`.
    left_sort_exprs: LexOrdering,
    /// Ordering required of the right input: the right join keys, each paired
    /// with the corresponding entry of `sort_options`.
    right_sort_exprs: LexOrdering,
    /// Sort options per join key; `try_new` rejects a length different from
    /// `on.len()`.
    pub sort_options: Vec<SortOptions>,
    /// How NULL join keys compare (`NullEqualsNull` makes them equal).
    pub null_equality: NullEquality,
    /// Plan properties computed once by `compute_properties` in `try_new`.
    cache: PlanProperties,
}
132
133impl SortMergeJoinExec {
134 pub fn try_new(
139 left: Arc<dyn ExecutionPlan>,
140 right: Arc<dyn ExecutionPlan>,
141 on: JoinOn,
142 filter: Option<JoinFilter>,
143 join_type: JoinType,
144 sort_options: Vec<SortOptions>,
145 null_equality: NullEquality,
146 ) -> Result<Self> {
147 let left_schema = left.schema();
148 let right_schema = right.schema();
149
150 check_join_is_valid(&left_schema, &right_schema, &on)?;
151 if sort_options.len() != on.len() {
152 return plan_err!(
153 "Expected number of sort options: {}, actual: {}",
154 on.len(),
155 sort_options.len()
156 );
157 }
158
159 let (left_sort_exprs, right_sort_exprs): (Vec<_>, Vec<_>) = on
160 .iter()
161 .zip(sort_options.iter())
162 .map(|((l, r), sort_op)| {
163 let left = PhysicalSortExpr {
164 expr: Arc::clone(l),
165 options: *sort_op,
166 };
167 let right = PhysicalSortExpr {
168 expr: Arc::clone(r),
169 options: *sort_op,
170 };
171 (left, right)
172 })
173 .unzip();
174 let Some(left_sort_exprs) = LexOrdering::new(left_sort_exprs) else {
175 return plan_err!(
176 "SortMergeJoinExec requires valid sort expressions for its left side"
177 );
178 };
179 let Some(right_sort_exprs) = LexOrdering::new(right_sort_exprs) else {
180 return plan_err!(
181 "SortMergeJoinExec requires valid sort expressions for its right side"
182 );
183 };
184
185 let schema =
186 Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0);
187 let cache =
188 Self::compute_properties(&left, &right, Arc::clone(&schema), join_type, &on)?;
189 Ok(Self {
190 left,
191 right,
192 on,
193 filter,
194 join_type,
195 schema,
196 metrics: ExecutionPlanMetricsSet::new(),
197 left_sort_exprs,
198 right_sort_exprs,
199 sort_options,
200 null_equality,
201 cache,
202 })
203 }
204
205 pub fn probe_side(join_type: &JoinType) -> JoinSide {
208 match join_type {
211 JoinType::Right
213 | JoinType::RightSemi
214 | JoinType::RightAnti
215 | JoinType::RightMark => JoinSide::Right,
216 JoinType::Inner
217 | JoinType::Left
218 | JoinType::Full
219 | JoinType::LeftAnti
220 | JoinType::LeftSemi
221 | JoinType::LeftMark => JoinSide::Left,
222 }
223 }
224
225 fn maintains_input_order(join_type: JoinType) -> Vec<bool> {
227 match join_type {
228 JoinType::Inner => vec![true, false],
229 JoinType::Left
230 | JoinType::LeftSemi
231 | JoinType::LeftAnti
232 | JoinType::LeftMark => vec![true, false],
233 JoinType::Right
234 | JoinType::RightSemi
235 | JoinType::RightAnti
236 | JoinType::RightMark => {
237 vec![false, true]
238 }
239 _ => vec![false, false],
240 }
241 }
242
243 pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
245 &self.on
246 }
247
248 pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
250 &self.right
251 }
252
253 pub fn join_type(&self) -> JoinType {
255 self.join_type
256 }
257
258 pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
260 &self.left
261 }
262
263 pub fn filter(&self) -> &Option<JoinFilter> {
265 &self.filter
266 }
267
268 pub fn sort_options(&self) -> &[SortOptions] {
270 &self.sort_options
271 }
272
273 pub fn null_equality(&self) -> NullEquality {
275 self.null_equality
276 }
277
278 fn compute_properties(
280 left: &Arc<dyn ExecutionPlan>,
281 right: &Arc<dyn ExecutionPlan>,
282 schema: SchemaRef,
283 join_type: JoinType,
284 join_on: JoinOnRef,
285 ) -> Result<PlanProperties> {
286 let eq_properties = join_equivalence_properties(
288 left.equivalence_properties().clone(),
289 right.equivalence_properties().clone(),
290 &join_type,
291 schema,
292 &Self::maintains_input_order(join_type),
293 Some(Self::probe_side(&join_type)),
294 join_on,
295 )?;
296
297 let output_partitioning =
298 symmetric_join_output_partitioning(left, right, &join_type)?;
299
300 Ok(PlanProperties::new(
301 eq_properties,
302 output_partitioning,
303 EmissionType::Incremental,
304 boundedness_from_children([left, right]),
305 ))
306 }
307
308 pub fn swap_inputs(&self) -> Result<Arc<dyn ExecutionPlan>> {
314 let left = self.left();
315 let right = self.right();
316 let new_join = SortMergeJoinExec::try_new(
317 Arc::clone(right),
318 Arc::clone(left),
319 self.on()
320 .iter()
321 .map(|(l, r)| (Arc::clone(r), Arc::clone(l)))
322 .collect::<Vec<_>>(),
323 self.filter().as_ref().map(JoinFilter::swap),
324 self.join_type().swap(),
325 self.sort_options.clone(),
326 self.null_equality,
327 )?;
328
329 if matches!(
332 self.join_type(),
333 JoinType::LeftSemi
334 | JoinType::RightSemi
335 | JoinType::LeftAnti
336 | JoinType::RightAnti
337 ) {
338 Ok(Arc::new(new_join))
339 } else {
340 reorder_output_after_swap(Arc::new(new_join), &left.schema(), &right.schema())
341 }
342 }
343}
344
345impl DisplayAs for SortMergeJoinExec {
346 fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
347 match t {
348 DisplayFormatType::Default | DisplayFormatType::Verbose => {
349 let on = self
350 .on
351 .iter()
352 .map(|(c1, c2)| format!("({c1}, {c2})"))
353 .collect::<Vec<String>>()
354 .join(", ");
355 let display_null_equality =
356 if matches!(self.null_equality(), NullEquality::NullEqualsNull) {
357 ", NullsEqual: true"
358 } else {
359 ""
360 };
361 write!(
362 f,
363 "{}: join_type={:?}, on=[{}]{}{}",
364 Self::static_name(),
365 self.join_type,
366 on,
367 self.filter.as_ref().map_or_else(
368 || "".to_string(),
369 |f| format!(", filter={}", f.expression())
370 ),
371 display_null_equality,
372 )
373 }
374 DisplayFormatType::TreeRender => {
375 let on = self
376 .on
377 .iter()
378 .map(|(c1, c2)| {
379 format!("({} = {})", fmt_sql(c1.as_ref()), fmt_sql(c2.as_ref()))
380 })
381 .collect::<Vec<String>>()
382 .join(", ");
383
384 if self.join_type() != JoinType::Inner {
385 writeln!(f, "join_type={:?}", self.join_type)?;
386 }
387 writeln!(f, "on={on}")?;
388
389 if matches!(self.null_equality(), NullEquality::NullEqualsNull) {
390 writeln!(f, "NullsEqual: true")?;
391 }
392
393 Ok(())
394 }
395 }
396 }
397}
398
399impl ExecutionPlan for SortMergeJoinExec {
400 fn name(&self) -> &'static str {
401 "SortMergeJoinExec"
402 }
403
404 fn as_any(&self) -> &dyn Any {
405 self
406 }
407
408 fn properties(&self) -> &PlanProperties {
409 &self.cache
410 }
411
412 fn required_input_distribution(&self) -> Vec<Distribution> {
413 let (left_expr, right_expr) = self
414 .on
415 .iter()
416 .map(|(l, r)| (Arc::clone(l), Arc::clone(r)))
417 .unzip();
418 vec![
419 Distribution::HashPartitioned(left_expr),
420 Distribution::HashPartitioned(right_expr),
421 ]
422 }
423
424 fn required_input_ordering(&self) -> Vec<Option<OrderingRequirements>> {
425 vec![
426 Some(OrderingRequirements::from(self.left_sort_exprs.clone())),
427 Some(OrderingRequirements::from(self.right_sort_exprs.clone())),
428 ]
429 }
430
431 fn maintains_input_order(&self) -> Vec<bool> {
432 Self::maintains_input_order(self.join_type)
433 }
434
435 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
436 vec![&self.left, &self.right]
437 }
438
439 fn with_new_children(
440 self: Arc<Self>,
441 children: Vec<Arc<dyn ExecutionPlan>>,
442 ) -> Result<Arc<dyn ExecutionPlan>> {
443 match &children[..] {
444 [left, right] => Ok(Arc::new(SortMergeJoinExec::try_new(
445 Arc::clone(left),
446 Arc::clone(right),
447 self.on.clone(),
448 self.filter.clone(),
449 self.join_type,
450 self.sort_options.clone(),
451 self.null_equality,
452 )?)),
453 _ => internal_err!("SortMergeJoin wrong number of children"),
454 }
455 }
456
457 fn execute(
458 &self,
459 partition: usize,
460 context: Arc<TaskContext>,
461 ) -> Result<SendableRecordBatchStream> {
462 let left_partitions = self.left.output_partitioning().partition_count();
463 let right_partitions = self.right.output_partitioning().partition_count();
464 assert_eq_or_internal_err!(
465 left_partitions,
466 right_partitions,
467 "Invalid SortMergeJoinExec, partition count mismatch {left_partitions}!={right_partitions},\
468 consider using RepartitionExec"
469 );
470 let (on_left, on_right) = self.on.iter().cloned().unzip();
471 let (streamed, buffered, on_streamed, on_buffered) =
472 if SortMergeJoinExec::probe_side(&self.join_type) == JoinSide::Left {
473 (
474 Arc::clone(&self.left),
475 Arc::clone(&self.right),
476 on_left,
477 on_right,
478 )
479 } else {
480 (
481 Arc::clone(&self.right),
482 Arc::clone(&self.left),
483 on_right,
484 on_left,
485 )
486 };
487
488 let streamed = streamed.execute(partition, Arc::clone(&context))?;
490 let buffered = buffered.execute(partition, Arc::clone(&context))?;
491
492 let batch_size = context.session_config().batch_size();
494
495 let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
497 .register(context.memory_pool());
498
499 Ok(Box::pin(SortMergeJoinStream::try_new(
501 context.session_config().spill_compression(),
502 Arc::clone(&self.schema),
503 self.sort_options.clone(),
504 self.null_equality,
505 streamed,
506 buffered,
507 on_streamed,
508 on_buffered,
509 self.filter.clone(),
510 self.join_type,
511 batch_size,
512 SortMergeJoinMetrics::new(partition, &self.metrics),
513 reservation,
514 context.runtime_env(),
515 )?))
516 }
517
518 fn metrics(&self) -> Option<MetricsSet> {
519 Some(self.metrics.clone_inner())
520 }
521
522 fn statistics(&self) -> Result<Statistics> {
523 self.partition_statistics(None)
524 }
525
526 fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> {
527 estimate_join_statistics(
539 self.left.partition_statistics(partition)?,
540 self.right.partition_statistics(partition)?,
541 &self.on,
542 &self.join_type,
543 &self.schema,
544 )
545 }
546
547 fn try_swapping_with_projection(
551 &self,
552 projection: &ProjectionExec,
553 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
554 let Some(projection_as_columns) = physical_to_column_exprs(projection.expr())
556 else {
557 return Ok(None);
558 };
559
560 let (far_right_left_col_ind, far_left_right_col_ind) = join_table_borders(
561 self.left().schema().fields().len(),
562 &projection_as_columns,
563 );
564
565 if !join_allows_pushdown(
566 &projection_as_columns,
567 &self.schema(),
568 far_right_left_col_ind,
569 far_left_right_col_ind,
570 ) {
571 return Ok(None);
572 }
573
574 let Some(new_on) = update_join_on(
575 &projection_as_columns[0..=far_right_left_col_ind as _],
576 &projection_as_columns[far_left_right_col_ind as _..],
577 self.on(),
578 self.left().schema().fields().len(),
579 ) else {
580 return Ok(None);
581 };
582
583 let (new_left, new_right) = new_join_children(
584 &projection_as_columns,
585 far_right_left_col_ind,
586 far_left_right_col_ind,
587 self.children()[0],
588 self.children()[1],
589 )?;
590
591 Ok(Some(Arc::new(SortMergeJoinExec::try_new(
592 Arc::new(new_left),
593 Arc::new(new_right),
594 new_on,
595 self.filter.clone(),
596 self.join_type,
597 self.sort_options.clone(),
598 self.null_equality,
599 )?)))
600 }
601}