Skip to main content

datafusion_physical_optimizer/
projection_pushdown.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This file implements the `ProjectionPushdown` physical optimization rule.
19//! The function [`remove_unnecessary_projections`] tries to push down all
20//! projections one by one if the operator below is amenable to this. If a
21//! projection reaches a source, it can even disappear from the plan entirely.
22
23use crate::PhysicalOptimizerRule;
24use arrow::datatypes::{Fields, Schema, SchemaRef};
25use datafusion_common::alias::AliasGenerator;
26use std::collections::HashSet;
27use std::sync::Arc;
28
29use datafusion_common::config::ConfigOptions;
30use datafusion_common::tree_node::{
31    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
32};
33use datafusion_common::{JoinSide, JoinType, Result};
34use datafusion_physical_expr::expressions::Column;
35use datafusion_physical_expr_common::physical_expr::{PhysicalExpr, is_volatile};
36use datafusion_physical_plan::ExecutionPlan;
37use datafusion_physical_plan::joins::NestedLoopJoinExec;
38use datafusion_physical_plan::joins::utils::{ColumnIndex, JoinFilter};
39use datafusion_physical_plan::projection::{
40    ProjectionExec, remove_unnecessary_projections,
41};
42
/// This rule inspects `ProjectionExec`'s in the given physical plan and tries to
/// remove them or swap them with their children.
///
/// Furthermore, it tries to push down projections from nested loop join filters that only
/// depend on one side of the join. By pushing these projections down, functions that only
/// depend on one side of the join no longer have to be evaluated for the cartesian product of
/// the two sides; they are computed by a `ProjectionExec` on the corresponding join input
/// instead.
#[derive(Default, Debug)]
pub struct ProjectionPushdown {}
51
52impl ProjectionPushdown {
53    #[expect(missing_docs)]
54    pub fn new() -> Self {
55        Self {}
56    }
57}
58
59impl PhysicalOptimizerRule for ProjectionPushdown {
60    fn optimize(
61        &self,
62        plan: Arc<dyn ExecutionPlan>,
63        _config: &ConfigOptions,
64    ) -> Result<Arc<dyn ExecutionPlan>> {
65        let alias_generator = AliasGenerator::new();
66        let plan = plan
67            .transform_up(|plan| match plan.downcast_ref::<NestedLoopJoinExec>() {
68                None => Ok(Transformed::no(plan)),
69                Some(hash_join) => try_push_down_join_filter(
70                    Arc::clone(&plan),
71                    hash_join,
72                    &alias_generator,
73                ),
74            })
75            .map(|t| t.data)?;
76
77        plan.transform_down(remove_unnecessary_projections).data()
78    }
79
80    fn name(&self) -> &str {
81        "ProjectionPushdown"
82    }
83
84    fn schema_check(&self) -> bool {
85        true
86    }
87}
88
89/// Tries to push down parts of the filter.
90///
91/// See [JoinFilterRewriter] for details.
92fn try_push_down_join_filter(
93    original_plan: Arc<dyn ExecutionPlan>,
94    join: &NestedLoopJoinExec,
95    alias_generator: &AliasGenerator,
96) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
97    // Mark joins are currently not supported.
98    if matches!(join.join_type(), JoinType::LeftMark | JoinType::RightMark) {
99        return Ok(Transformed::no(original_plan));
100    }
101
102    let projections = join.projection();
103    let Some(filter) = join.filter() else {
104        return Ok(Transformed::no(original_plan));
105    };
106
107    let original_lhs_length = join.left().schema().fields().len();
108    let original_rhs_length = join.right().schema().fields().len();
109
110    let lhs_rewrite = try_push_down_projection(
111        Arc::clone(&join.right().schema()),
112        Arc::clone(join.left()),
113        JoinSide::Left,
114        filter.clone(),
115        alias_generator,
116    )?;
117    let rhs_rewrite = try_push_down_projection(
118        Arc::clone(&lhs_rewrite.data.0.schema()),
119        Arc::clone(join.right()),
120        JoinSide::Right,
121        lhs_rewrite.data.1,
122        alias_generator,
123    )?;
124    if !lhs_rewrite.transformed && !rhs_rewrite.transformed {
125        return Ok(Transformed::no(original_plan));
126    }
127
128    let join_filter = minimize_join_filter(
129        Arc::clone(rhs_rewrite.data.1.expression()),
130        rhs_rewrite.data.1.column_indices(),
131        lhs_rewrite.data.0.schema().as_ref(),
132        rhs_rewrite.data.0.schema().as_ref(),
133    );
134
135    let new_lhs_length = lhs_rewrite.data.0.schema().fields.len();
136    let projections = match projections.as_ref() {
137        None => match join.join_type() {
138            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
139                // Build projections that ignore the newly projected columns.
140                let mut projections = Vec::new();
141                projections.extend(0..original_lhs_length);
142                projections.extend(new_lhs_length..new_lhs_length + original_rhs_length);
143                projections
144            }
145            JoinType::LeftSemi | JoinType::LeftAnti => {
146                // Only return original left columns
147                let mut projections = Vec::new();
148                projections.extend(0..original_lhs_length);
149                projections
150            }
151            JoinType::RightSemi | JoinType::RightAnti => {
152                // Only return original right columns
153                let mut projections = Vec::new();
154                projections.extend(0..original_rhs_length);
155                projections
156            }
157            _ => unreachable!("Unsupported join type"),
158        },
159        Some(projections) => {
160            let rhs_offset = new_lhs_length - original_lhs_length;
161            projections
162                .iter()
163                .map(|idx| {
164                    if *idx >= original_lhs_length {
165                        idx + rhs_offset
166                    } else {
167                        *idx
168                    }
169                })
170                .collect()
171        }
172    };
173
174    Ok(Transformed::yes(Arc::new(NestedLoopJoinExec::try_new(
175        lhs_rewrite.data.0,
176        rhs_rewrite.data.0,
177        Some(join_filter),
178        join.join_type(),
179        Some(projections),
180    )?)))
181}
182
183/// Tries to push down parts of `expr` into the `join_side`.
184fn try_push_down_projection(
185    other_schema: SchemaRef,
186    plan: Arc<dyn ExecutionPlan>,
187    join_side: JoinSide,
188    join_filter: JoinFilter,
189    alias_generator: &AliasGenerator,
190) -> Result<Transformed<(Arc<dyn ExecutionPlan>, JoinFilter)>> {
191    let expr = Arc::clone(join_filter.expression());
192    let original_plan_schema = plan.schema();
193    let mut rewriter = JoinFilterRewriter::new(
194        join_side,
195        original_plan_schema.as_ref(),
196        join_filter.column_indices().to_vec(),
197        alias_generator,
198    );
199    let new_expr = rewriter.rewrite(expr)?;
200
201    if new_expr.transformed {
202        let new_join_side =
203            ProjectionExec::try_new(rewriter.join_side_projections, plan)?;
204        let new_schema = Arc::clone(&new_join_side.schema());
205
206        let (lhs_schema, rhs_schema) = match join_side {
207            JoinSide::Left => (new_schema, other_schema),
208            JoinSide::Right => (other_schema, new_schema),
209            JoinSide::None => unreachable!("Mark join not supported"),
210        };
211        let intermediate_schema = rewriter
212            .intermediate_column_indices
213            .iter()
214            .map(|ci| match ci.side {
215                JoinSide::Left => Arc::clone(&lhs_schema.fields[ci.index]),
216                JoinSide::Right => Arc::clone(&rhs_schema.fields[ci.index]),
217                JoinSide::None => unreachable!("Mark join not supported"),
218            })
219            .collect::<Fields>();
220
221        let join_filter = JoinFilter::new(
222            new_expr.data,
223            rewriter.intermediate_column_indices,
224            Arc::new(Schema::new(intermediate_schema)),
225        );
226        Ok(Transformed::yes((Arc::new(new_join_side), join_filter)))
227    } else {
228        Ok(Transformed::no((plan, join_filter)))
229    }
230}
231
232/// Creates a new [JoinFilter] and tries to minimize the internal schema.
233///
234/// This could eliminate some columns that were only part of a computation that has been pushed
235/// down. As this computation is now materialized on one side of the join, the original input
236/// columns are not needed anymore.
237fn minimize_join_filter(
238    expr: Arc<dyn PhysicalExpr>,
239    old_column_indices: &[ColumnIndex],
240    lhs_schema: &Schema,
241    rhs_schema: &Schema,
242) -> JoinFilter {
243    let mut used_columns = HashSet::new();
244    expr.apply(|expr| {
245        if let Some(col) = expr.downcast_ref::<Column>() {
246            used_columns.insert(col.index());
247        }
248        Ok(TreeNodeRecursion::Continue)
249    })
250    .expect("Closure cannot fail");
251
252    let new_column_indices = old_column_indices
253        .iter()
254        .enumerate()
255        .filter(|(idx, _)| used_columns.contains(idx))
256        .map(|(_, ci)| ci.clone())
257        .collect::<Vec<_>>();
258    let fields = new_column_indices
259        .iter()
260        .map(|ci| match ci.side {
261            JoinSide::Left => lhs_schema.field(ci.index).clone(),
262            JoinSide::Right => rhs_schema.field(ci.index).clone(),
263            JoinSide::None => unreachable!("Mark join not supported"),
264        })
265        .collect::<Fields>();
266
267    let final_expr = expr
268        .transform_up(|expr| match expr.downcast_ref::<Column>() {
269            None => Ok(Transformed::no(expr)),
270            Some(column) => {
271                let new_idx = used_columns
272                    .iter()
273                    .filter(|idx| **idx < column.index())
274                    .count();
275                let new_column = Column::new(column.name(), new_idx);
276                Ok(Transformed::yes(
277                    Arc::new(new_column) as Arc<dyn PhysicalExpr>
278                ))
279            }
280        })
281        .expect("Closure cannot fail");
282
283    JoinFilter::new(
284        final_expr.data,
285        new_column_indices,
286        Arc::new(Schema::new(fields)),
287    )
288}
289
/// Implements the push-down machinery.
///
/// The rewriter starts at the top of the filter expression and traverses the expression tree. For
/// each (sub-)expression, the rewriter checks whether it only refers to one side of the join. If
/// this is never the case, no subexpressions of the filter can be pushed down. If there is a
/// subexpression that can be computed using only one side of the join, the entire subexpression is
/// pushed down to the join side.
struct JoinFilterRewriter<'a> {
    /// The join side that single-sided subexpressions are pushed down to.
    join_side: JoinSide,
    /// Schema of the join input on `join_side`, used to resolve the columns of a pushed-down
    /// expression against that input.
    join_side_schema: &'a Schema,
    /// Expressions and output names of the projection for the join input on `join_side`.
    /// Starts with one pass-through column per input field; every pushed-down expression is
    /// appended.
    join_side_projections: Vec<(Arc<dyn PhysicalExpr>, String)>,
    /// Column indices of the join filter. An entry is appended for every pushed-down
    /// expression.
    intermediate_column_indices: Vec<ColumnIndex>,
    /// Generates the names of the pushed-down columns.
    alias_generator: &'a AliasGenerator,
}
304
305impl<'a> JoinFilterRewriter<'a> {
306    /// Creates a new [JoinFilterRewriter].
307    fn new(
308        join_side: JoinSide,
309        join_side_schema: &'a Schema,
310        column_indices: Vec<ColumnIndex>,
311        alias_generator: &'a AliasGenerator,
312    ) -> Self {
313        let projections = join_side_schema
314            .fields()
315            .iter()
316            .enumerate()
317            .map(|(idx, field)| {
318                (
319                    Arc::new(Column::new(field.name(), idx)) as Arc<dyn PhysicalExpr>,
320                    field.name().to_string(),
321                )
322            })
323            .collect();
324
325        Self {
326            join_side,
327            join_side_schema,
328            join_side_projections: projections,
329            intermediate_column_indices: column_indices,
330            alias_generator,
331        }
332    }
333
334    /// Executes the push-down machinery on `expr`.
335    ///
336    /// See the [JoinFilterRewriter] for further information.
337    fn rewrite(
338        &mut self,
339        expr: Arc<dyn PhysicalExpr>,
340    ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
341        let depends_on_this_side = self.depends_on_join_side(&expr, self.join_side)?;
342        // We don't push down things that do not depend on this side (other side or no side).
343        if !depends_on_this_side {
344            return Ok(Transformed::no(expr));
345        }
346
347        // Recurse if there is a dependency to both sides or if the entire expression is volatile.
348        let depends_on_other_side =
349            self.depends_on_join_side(&expr, self.join_side.negate())?;
350        if depends_on_other_side || is_volatile(&expr) {
351            return expr.map_children(|expr| self.rewrite(expr));
352        }
353
354        // There is only a dependency on this side.
355
356        // If this expression has no children, we do not push down, as it should already be a column
357        // reference.
358        if expr.children().is_empty() {
359            return Ok(Transformed::no(expr));
360        }
361
362        // Otherwise, we push down a projection.
363        let alias = self.alias_generator.next("join_proj_push_down");
364        let idx = self.create_new_column(alias.clone(), expr)?;
365
366        Ok(Transformed::yes(
367            Arc::new(Column::new(&alias, idx)) as Arc<dyn PhysicalExpr>
368        ))
369    }
370
371    /// Creates a new column in the current join side.
372    fn create_new_column(
373        &mut self,
374        name: String,
375        expr: Arc<dyn PhysicalExpr>,
376    ) -> Result<usize> {
377        // First, add a new projection. The expression must be rewritten, as it is no longer
378        // executed against the filter schema.
379        let new_idx = self.join_side_projections.len();
380        let rewritten_expr = expr.transform_up(|expr| {
381            Ok(match expr.downcast_ref::<Column>() {
382                None => Transformed::no(expr),
383                Some(column) => {
384                    let intermediate_column =
385                        &self.intermediate_column_indices[column.index()];
386                    assert_eq!(intermediate_column.side, self.join_side);
387
388                    let join_side_index = intermediate_column.index;
389                    let field = self.join_side_schema.field(join_side_index);
390                    let new_column = Column::new(field.name(), join_side_index);
391                    Transformed::yes(Arc::new(new_column) as Arc<dyn PhysicalExpr>)
392                }
393            })
394        })?;
395        self.join_side_projections.push((rewritten_expr.data, name));
396
397        // Then, update the column indices
398        let new_intermediate_idx = self.intermediate_column_indices.len();
399        let idx = ColumnIndex {
400            index: new_idx,
401            side: self.join_side,
402        };
403        self.intermediate_column_indices.push(idx);
404
405        Ok(new_intermediate_idx)
406    }
407
408    /// Checks whether the entire expression depends on the given `join_side`.
409    fn depends_on_join_side(
410        &mut self,
411        expr: &Arc<dyn PhysicalExpr>,
412        join_side: JoinSide,
413    ) -> Result<bool> {
414        let mut result = false;
415        expr.apply(|expr| match expr.downcast_ref::<Column>() {
416            None => Ok(TreeNodeRecursion::Continue),
417            Some(c) => {
418                let column_index = &self.intermediate_column_indices[c.index()];
419                if column_index.side == join_side {
420                    result = true;
421                    return Ok(TreeNodeRecursion::Stop);
422                }
423                Ok(TreeNodeRecursion::Continue)
424            }
425        })?;
426
427        Ok(result)
428    }
429}
430
431#[cfg(test)]
432mod test {
433    use super::*;
434    use arrow::datatypes::{DataType, Field, FieldRef, Schema};
435    use datafusion_expr_common::operator::Operator;
436    use datafusion_functions::math::random;
437    use datafusion_physical_expr::ScalarFunctionExpr;
438    use datafusion_physical_expr::expressions::{binary, lit};
439    use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
440    use datafusion_physical_plan::displayable;
441    use datafusion_physical_plan::empty::EmptyExec;
442    use insta::assert_snapshot;
443    use std::sync::Arc;
444
445    #[tokio::test]
446    async fn no_computation_does_not_project() -> Result<()> {
447        let (left_schema, right_schema) = create_simple_schemas();
448        let optimized_plan = run_test(
449            left_schema,
450            right_schema,
451            a_x(),
452            None,
453            a_greater_than_x,
454            JoinType::Inner,
455        )?;
456
457        assert_snapshot!(optimized_plan, @r"
458        NestedLoopJoinExec: join_type=Inner, filter=a@0 > x@1
459          EmptyExec
460          EmptyExec
461        ");
462        Ok(())
463    }
464
465    #[tokio::test]
466    async fn simple_push_down() -> Result<()> {
467        let (left_schema, right_schema) = create_simple_schemas();
468        let optimized_plan = run_test(
469            left_schema,
470            right_schema,
471            a_x(),
472            None,
473            a_plus_one_greater_than_x_plus_one,
474            JoinType::Inner,
475        )?;
476
477        assert_snapshot!(optimized_plan, @r"
478        NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, x@2]
479          ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1]
480            EmptyExec
481          ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2]
482            EmptyExec
483        ");
484        Ok(())
485    }
486
487    #[tokio::test]
488    async fn does_not_push_down_short_circuiting_expressions() -> Result<()> {
489        let (left_schema, right_schema) = create_simple_schemas();
490        let optimized_plan = run_test(
491            left_schema,
492            right_schema,
493            a_x(),
494            None,
495            |schema| {
496                binary(
497                    lit(false),
498                    Operator::And,
499                    a_plus_one_greater_than_x_plus_one(schema)?,
500                    schema,
501                )
502            },
503            JoinType::Inner,
504        )?;
505
506        assert_snapshot!(optimized_plan, @r"
507        NestedLoopJoinExec: join_type=Inner, filter=false AND join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, x@2]
508          ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1]
509            EmptyExec
510          ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2]
511            EmptyExec
512        ");
513        Ok(())
514    }
515
516    #[tokio::test]
517    async fn does_not_push_down_volatile_functions() -> Result<()> {
518        let (left_schema, right_schema) = create_simple_schemas();
519        let optimized_plan = run_test(
520            left_schema,
521            right_schema,
522            a_x(),
523            None,
524            a_plus_rand_greater_than_x,
525            JoinType::Inner,
526        )?;
527
528        assert_snapshot!(optimized_plan, @r"
529        NestedLoopJoinExec: join_type=Inner, filter=a@0 + rand() > x@1
530          EmptyExec
531          EmptyExec
532        ");
533        Ok(())
534    }
535
536    #[tokio::test]
537    async fn complex_schema_push_down() -> Result<()> {
538        let (left_schema, right_schema) = create_complex_schemas();
539
540        let optimized_plan = run_test(
541            left_schema,
542            right_schema,
543            a_b_x_z(),
544            None,
545            a_plus_b_greater_than_x_plus_z,
546            JoinType::Inner,
547        )?;
548
549        assert_snapshot!(optimized_plan, @r"
550        NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0, b@1, c@2, x@4, y@5, z@6]
551          ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c, a@0 + b@1 as join_proj_push_down_1]
552            EmptyExec
553          ProjectionExec: expr=[x@0 as x, y@1 as y, z@2 as z, x@0 + z@2 as join_proj_push_down_2]
554            EmptyExec
555        ");
556        Ok(())
557    }
558
559    #[tokio::test]
560    async fn push_down_with_existing_projections() -> Result<()> {
561        let (left_schema, right_schema) = create_complex_schemas();
562
563        let optimized_plan = run_test(
564            left_schema,
565            right_schema,
566            a_b_x_z(),
567            Some(vec![1, 3, 5]), // ("b", "x", "z")
568            a_plus_b_greater_than_x_plus_z,
569            JoinType::Inner,
570        )?;
571
572        assert_snapshot!(optimized_plan, @r"
573        NestedLoopJoinExec: join_type=Inner, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[b@1, x@4, z@6]
574          ProjectionExec: expr=[a@0 as a, b@1 as b, c@2 as c, a@0 + b@1 as join_proj_push_down_1]
575            EmptyExec
576          ProjectionExec: expr=[x@0 as x, y@1 as y, z@2 as z, x@0 + z@2 as join_proj_push_down_2]
577            EmptyExec
578        ");
579        Ok(())
580    }
581
582    #[tokio::test]
583    async fn left_semi_join_projection() -> Result<()> {
584        let (left_schema, right_schema) = create_simple_schemas();
585
586        let left_semi_join_plan = run_test(
587            left_schema.clone(),
588            right_schema.clone(),
589            a_x(),
590            None,
591            a_plus_one_greater_than_x_plus_one,
592            JoinType::LeftSemi,
593        )?;
594
595        assert_snapshot!(left_semi_join_plan, @r"
596        NestedLoopJoinExec: join_type=LeftSemi, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[a@0]
597          ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1]
598            EmptyExec
599          ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2]
600            EmptyExec
601        ");
602        Ok(())
603    }
604
605    #[tokio::test]
606    async fn right_semi_join_projection() -> Result<()> {
607        let (left_schema, right_schema) = create_simple_schemas();
608        let right_semi_join_plan = run_test(
609            left_schema,
610            right_schema,
611            a_x(),
612            None,
613            a_plus_one_greater_than_x_plus_one,
614            JoinType::RightSemi,
615        )?;
616        assert_snapshot!(right_semi_join_plan, @r"
617        NestedLoopJoinExec: join_type=RightSemi, filter=join_proj_push_down_1@0 > join_proj_push_down_2@1, projection=[x@0]
618          ProjectionExec: expr=[a@0 as a, a@0 + 1 as join_proj_push_down_1]
619            EmptyExec
620          ProjectionExec: expr=[x@0 as x, x@0 + 1 as join_proj_push_down_2]
621            EmptyExec
622        ");
623        Ok(())
624    }
625
626    fn run_test(
627        left_schema: Schema,
628        right_schema: Schema,
629        column_indices: Vec<ColumnIndex>,
630        existing_projections: Option<Vec<usize>>,
631        filter_expr_builder: impl FnOnce(&Schema) -> Result<Arc<dyn PhysicalExpr>>,
632        join_type: JoinType,
633    ) -> Result<String> {
634        let left = Arc::new(EmptyExec::new(Arc::new(left_schema.clone())));
635        let right = Arc::new(EmptyExec::new(Arc::new(right_schema.clone())));
636
637        let join_fields: Vec<_> = column_indices
638            .iter()
639            .map(|ci| match ci.side {
640                JoinSide::Left => left_schema.field(ci.index).clone(),
641                JoinSide::Right => right_schema.field(ci.index).clone(),
642                JoinSide::None => unreachable!(),
643            })
644            .collect();
645        let join_schema = Arc::new(Schema::new(join_fields));
646
647        let filter_expr = filter_expr_builder(join_schema.as_ref())?;
648
649        let join_filter = JoinFilter::new(filter_expr, column_indices, join_schema);
650
651        let join = NestedLoopJoinExec::try_new(
652            left,
653            right,
654            Some(join_filter),
655            &join_type,
656            existing_projections,
657        )?;
658
659        let optimizer = ProjectionPushdown::new();
660        let optimized_plan = optimizer.optimize(Arc::new(join), &Default::default())?;
661
662        let displayable_plan = displayable(optimized_plan.as_ref()).indent(false);
663        Ok(displayable_plan.to_string())
664    }
665
666    fn create_simple_schemas() -> (Schema, Schema) {
667        let left_schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
668        let right_schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]);
669
670        (left_schema, right_schema)
671    }
672
673    fn create_complex_schemas() -> (Schema, Schema) {
674        let left_schema = Schema::new(vec![
675            Field::new("a", DataType::Int32, false),
676            Field::new("b", DataType::Int32, false),
677            Field::new("c", DataType::Int32, false),
678        ]);
679
680        let right_schema = Schema::new(vec![
681            Field::new("x", DataType::Int32, false),
682            Field::new("y", DataType::Int32, false),
683            Field::new("z", DataType::Int32, false),
684        ]);
685
686        (left_schema, right_schema)
687    }
688
689    fn a_x() -> Vec<ColumnIndex> {
690        vec![
691            ColumnIndex {
692                index: 0,
693                side: JoinSide::Left,
694            },
695            ColumnIndex {
696                index: 0,
697                side: JoinSide::Right,
698            },
699        ]
700    }
701
702    fn a_b_x_z() -> Vec<ColumnIndex> {
703        vec![
704            ColumnIndex {
705                index: 0,
706                side: JoinSide::Left,
707            },
708            ColumnIndex {
709                index: 1,
710                side: JoinSide::Left,
711            },
712            ColumnIndex {
713                index: 0,
714                side: JoinSide::Right,
715            },
716            ColumnIndex {
717                index: 2,
718                side: JoinSide::Right,
719            },
720        ]
721    }
722
723    fn a_plus_one_greater_than_x_plus_one(
724        join_schema: &Schema,
725    ) -> Result<Arc<dyn PhysicalExpr>> {
726        let left_expr = binary(
727            Arc::new(Column::new("a", 0)),
728            Operator::Plus,
729            lit(1),
730            join_schema,
731        )?;
732        let right_expr = binary(
733            Arc::new(Column::new("x", 1)),
734            Operator::Plus,
735            lit(1),
736            join_schema,
737        )?;
738        binary(left_expr, Operator::Gt, right_expr, join_schema)
739    }
740
741    fn a_plus_rand_greater_than_x(join_schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
742        let left_expr = binary(
743            Arc::new(Column::new("a", 0)),
744            Operator::Plus,
745            Arc::new(ScalarFunctionExpr::new(
746                "rand",
747                random(),
748                vec![],
749                FieldRef::new(Field::new("out", DataType::Float64, false)),
750                Arc::new(ConfigOptions::default()),
751            )),
752            join_schema,
753        )?;
754        let right_expr = Arc::new(Column::new("x", 1));
755        binary(left_expr, Operator::Gt, right_expr, join_schema)
756    }
757
758    fn a_greater_than_x(join_schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
759        binary(
760            Arc::new(Column::new("a", 0)),
761            Operator::Gt,
762            Arc::new(Column::new("x", 1)),
763            join_schema,
764        )
765    }
766
767    fn a_plus_b_greater_than_x_plus_z(
768        join_schema: &Schema,
769    ) -> Result<Arc<dyn PhysicalExpr>> {
770        let lhs = binary(
771            Arc::new(Column::new("a", 0)),
772            Operator::Plus,
773            Arc::new(Column::new("b", 1)),
774            join_schema,
775        )?;
776        let rhs = binary(
777            Arc::new(Column::new("x", 2)),
778            Operator::Plus,
779            Arc::new(Column::new("z", 3)),
780            join_schema,
781        )?;
782        binary(lhs, Operator::Gt, rhs, join_schema)
783    }
784}