datafusion_optimizer/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utility functions leveraged by the query optimizer rules
19
20use std::collections::{BTreeSet, HashMap, HashSet};
21
22use crate::{OptimizerConfig, OptimizerRule};
23
24use crate::analyzer::type_coercion::TypeCoercionRewriter;
25use arrow::array::{new_null_array, Array, RecordBatch};
26use arrow::datatypes::{DataType, Field, Schema};
27use datafusion_common::cast::as_boolean_array;
28use datafusion_common::tree_node::{TransformedResult, TreeNode};
29use datafusion_common::{Column, DFSchema, Result, ScalarValue};
30use datafusion_expr::execution_props::ExecutionProps;
31use datafusion_expr::expr_rewriter::replace_col;
32use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr};
33use datafusion_physical_expr::create_physical_expr;
34use log::{debug, trace};
35use std::sync::Arc;
36
/// Re-export of `NamePreserver` for backwards compatibility,
/// as it was initially placed here and then moved elsewhere.
39pub use datafusion_expr::expr_rewriter::NamePreserver;
40
41/// Convenience rule for writing optimizers: recursively invoke
42/// optimize on plan's children and then return a node of the same
43/// type. Useful for optimizer rules which want to leave the type
44/// of plan unchanged but still apply to the children.
45/// This also handles the case when the `plan` is a [`LogicalPlan::Explain`].
46///
47/// Returning `Ok(None)` indicates that the plan can't be optimized by the `optimizer`.
48#[deprecated(
49    since = "40.0.0",
50    note = "please use OptimizerRule::apply_order with ApplyOrder::BottomUp instead"
51)]
52pub fn optimize_children(
53    optimizer: &impl OptimizerRule,
54    plan: &LogicalPlan,
55    config: &dyn OptimizerConfig,
56) -> Result<Option<LogicalPlan>> {
57    let mut new_inputs = Vec::with_capacity(plan.inputs().len());
58    let mut plan_is_changed = false;
59    for input in plan.inputs() {
60        if optimizer.supports_rewrite() {
61            let new_input = optimizer.rewrite(input.clone(), config)?;
62            plan_is_changed = plan_is_changed || new_input.transformed;
63            new_inputs.push(new_input.data);
64        } else {
65            #[allow(deprecated)]
66            let new_input = optimizer.try_optimize(input, config)?;
67            plan_is_changed = plan_is_changed || new_input.is_some();
68            new_inputs.push(new_input.unwrap_or_else(|| input.clone()))
69        }
70    }
71    if plan_is_changed {
72        let exprs = plan.expressions();
73        plan.with_new_exprs(exprs, new_inputs).map(Some)
74    } else {
75        Ok(None)
76    }
77}
78
79/// Returns true if `expr` contains all columns in `schema_cols`
80pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet<Column>) -> bool {
81    let column_refs = expr.column_refs();
82    // note can't use HashSet::intersect because of different types (owned vs References)
83    schema_cols
84        .iter()
85        .filter(|c| column_refs.contains(c))
86        .count()
87        == column_refs.len()
88}
89
90pub(crate) fn replace_qualified_name(
91    expr: Expr,
92    cols: &BTreeSet<Column>,
93    subquery_alias: &str,
94) -> Result<Expr> {
95    let alias_cols: Vec<Column> = cols
96        .iter()
97        .map(|col| Column::new(Some(subquery_alias), &col.name))
98        .collect();
99    let replace_map: HashMap<&Column, &Column> =
100        cols.iter().zip(alias_cols.iter()).collect();
101
102    replace_col(expr, &replace_map)
103}
104
105/// Log the plan in debug/tracing mode after some part of the optimizer runs
106pub fn log_plan(description: &str, plan: &LogicalPlan) {
107    debug!("{description}:\n{}\n", plan.display_indent());
108    trace!("{description}::\n{}\n", plan.display_indent_schema());
109}
110
111/// Determine whether a predicate can restrict NULLs. e.g.
112/// `c0 > 8` return true;
113/// `c0 IS NULL` return false.
114pub fn is_restrict_null_predicate<'a>(
115    predicate: Expr,
116    join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
117) -> Result<bool> {
118    if matches!(predicate, Expr::Column(_)) {
119        return Ok(true);
120    }
121
122    static DUMMY_COL_NAME: &str = "?";
123    let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
124    let input_schema = DFSchema::try_from(schema.clone())?;
125    let column = new_null_array(&DataType::Null, 1);
126    let input_batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
127    let execution_props = ExecutionProps::default();
128    let null_column = Column::from_name(DUMMY_COL_NAME);
129
130    let join_cols_to_replace = join_cols_of_predicate
131        .into_iter()
132        .map(|column| (column, &null_column))
133        .collect::<HashMap<_, _>>();
134
135    let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
136    let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
137    let phys_expr =
138        create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?;
139
140    let result_type = phys_expr.data_type(&schema)?;
141    if !matches!(&result_type, DataType::Boolean) {
142        return Ok(false);
143    }
144
145    // If result is single `true`, return false;
146    // If result is single `NULL` or `false`, return true;
147    Ok(match phys_expr.evaluate(&input_batch)? {
148        ColumnarValue::Array(array) => {
149            if array.len() == 1 {
150                let boolean_array = as_boolean_array(&array)?;
151                boolean_array.is_null(0) || !boolean_array.value(0)
152            } else {
153                false
154            }
155        }
156        ColumnarValue::Scalar(scalar) => matches!(
157            scalar,
158            ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
159        ),
160    })
161}
162
163fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
164    let mut expr_rewrite = TypeCoercionRewriter { schema };
165    expr.rewrite(&mut expr_rewrite).data()
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator};

    /// Checks `is_restrict_null_predicate` against a table of predicates
    /// over column `a`, treating `a` as the join column.
    #[test]
    fn expr_is_restrict_null_predicate() -> Result<()> {
        // Small builders for the pieces that repeat across cases.
        let a = || col("a");
        let null_literal = || Expr::Literal(ScalarValue::Null);
        let one_two_three = || vec![lit(1i64), lit(2i64), lit(3i64)];

        // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
        let case_with_null_else = case(a())
            .when(lit(1i64), lit(true))
            .when(lit(0i64), lit(false))
            .otherwise(null_literal())?;
        // CASE a WHEN 1 THEN true ELSE false END
        let true_when_one = case(a())
            .when(lit(1i64), lit(true))
            .otherwise(lit(false))?;
        // CASE a WHEN 0 THEN false ELSE true END
        let false_when_zero = case(a())
            .when(lit(0i64), lit(false))
            .otherwise(lit(true))?;
        // CASE a WHEN 0 THEN true ELSE false END
        let true_when_zero = case(a())
            .when(lit(0i64), lit(true))
            .otherwise(lit(false))?;

        let cases: Vec<(Expr, bool)> = vec![
            // a
            (a(), true),
            // a IS NULL
            (is_null(a()), false),
            // a IS NOT NULL
            (Expr::IsNotNull(Box::new(a())), true),
            // a = NULL
            (binary_expr(a(), Operator::Eq, null_literal()), true),
            // a > 8
            (binary_expr(a(), Operator::Gt, lit(8i64)), true),
            // a <= 8
            (binary_expr(a(), Operator::LtEq, lit(8i32)), true),
            // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
            (case_with_null_else, true),
            // CASE a WHEN 1 THEN true ELSE false END
            (true_when_one, true),
            // CASE a WHEN 0 THEN false ELSE true END
            (false_when_zero.clone(), false),
            // (CASE a WHEN 0 THEN false ELSE true END) OR false
            (
                binary_expr(false_when_zero, Operator::Or, lit(false)),
                false,
            ),
            // (CASE a WHEN 0 THEN true ELSE false END) OR false
            (binary_expr(true_when_zero, Operator::Or, lit(false)), true),
            // a IN (1, 2, 3)
            (in_list(a(), one_two_three(), false), true),
            // a NOT IN (1, 2, 3)
            (in_list(a(), one_two_three(), true), true),
            // a IN (NULL)
            (in_list(a(), vec![null_literal()], false), true),
            // a NOT IN (NULL)
            (in_list(a(), vec![null_literal()], true), true),
        ];

        let join_column = Column::from_name("a");
        for (predicate, expected) in cases {
            // Render up front: the predicate itself is consumed by the call.
            let rendered = predicate.to_string();
            let actual =
                is_restrict_null_predicate(predicate, std::iter::once(&join_column))?;
            assert_eq!(actual, expected, "{}", rendered);
        }

        Ok(())
    }
}
267}