datafusion_optimizer/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utility functions leveraged by the query optimizer rules
19
20use std::collections::{BTreeSet, HashMap, HashSet};
21
22use crate::analyzer::type_coercion::TypeCoercionRewriter;
23use arrow::array::{new_null_array, Array, RecordBatch};
24use arrow::datatypes::{DataType, Field, Schema};
25use datafusion_common::cast::as_boolean_array;
26use datafusion_common::tree_node::{TransformedResult, TreeNode};
27use datafusion_common::{Column, DFSchema, Result, ScalarValue};
28use datafusion_expr::execution_props::ExecutionProps;
29use datafusion_expr::expr_rewriter::replace_col;
30use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr};
31use datafusion_physical_expr::create_physical_expr;
32use log::{debug, trace};
33use std::sync::Arc;
34
/// Re-export of `NamePreserver` for backwards compatibility: it was
/// originally defined in this module and later moved to
/// `datafusion_expr::expr_rewriter`.
37pub use datafusion_expr::expr_rewriter::NamePreserver;
38
39/// Returns true if `expr` contains all columns in `schema_cols`
40pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet<Column>) -> bool {
41    let column_refs = expr.column_refs();
42    // note can't use HashSet::intersect because of different types (owned vs References)
43    schema_cols
44        .iter()
45        .filter(|c| column_refs.contains(c))
46        .count()
47        == column_refs.len()
48}
49
50pub(crate) fn replace_qualified_name(
51    expr: Expr,
52    cols: &BTreeSet<Column>,
53    subquery_alias: &str,
54) -> Result<Expr> {
55    let alias_cols: Vec<Column> = cols
56        .iter()
57        .map(|col| Column::new(Some(subquery_alias), &col.name))
58        .collect();
59    let replace_map: HashMap<&Column, &Column> =
60        cols.iter().zip(alias_cols.iter()).collect();
61
62    replace_col(expr, &replace_map)
63}
64
65/// Log the plan in debug/tracing mode after some part of the optimizer runs
66pub fn log_plan(description: &str, plan: &LogicalPlan) {
67    debug!("{description}:\n{}\n", plan.display_indent());
68    trace!("{description}::\n{}\n", plan.display_indent_schema());
69}
70
71/// Determine whether a predicate can restrict NULLs. e.g.
72/// `c0 > 8` return true;
73/// `c0 IS NULL` return false.
74pub fn is_restrict_null_predicate<'a>(
75    predicate: Expr,
76    join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
77) -> Result<bool> {
78    if matches!(predicate, Expr::Column(_)) {
79        return Ok(true);
80    }
81
82    // If result is single `true`, return false;
83    // If result is single `NULL` or `false`, return true;
84    Ok(
85        match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? {
86            ColumnarValue::Array(array) => {
87                if array.len() == 1 {
88                    let boolean_array = as_boolean_array(&array)?;
89                    boolean_array.is_null(0) || !boolean_array.value(0)
90                } else {
91                    false
92                }
93            }
94            ColumnarValue::Scalar(scalar) => matches!(
95                scalar,
96                ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
97            ),
98        },
99    )
100}
101
102/// Determines if an expression will always evaluate to null.
103/// `c0 + 8` return true
104/// `c0 IS NULL` return false
105/// `CASE WHEN c0 > 1 then 0 else 1` return false
106pub fn evaluates_to_null<'a>(
107    predicate: Expr,
108    null_columns: impl IntoIterator<Item = &'a Column>,
109) -> Result<bool> {
110    if matches!(predicate, Expr::Column(_)) {
111        return Ok(true);
112    }
113
114    Ok(
115        match evaluate_expr_with_null_column(predicate, null_columns)? {
116            ColumnarValue::Array(_) => false,
117            ColumnarValue::Scalar(scalar) => scalar.is_null(),
118        },
119    )
120}
121
122fn evaluate_expr_with_null_column<'a>(
123    predicate: Expr,
124    null_columns: impl IntoIterator<Item = &'a Column>,
125) -> Result<ColumnarValue> {
126    static DUMMY_COL_NAME: &str = "?";
127    let schema = Arc::new(Schema::new(vec![Field::new(
128        DUMMY_COL_NAME,
129        DataType::Null,
130        true,
131    )]));
132    let input_schema = DFSchema::try_from(Arc::clone(&schema))?;
133    let column = new_null_array(&DataType::Null, 1);
134    let input_batch = RecordBatch::try_new(schema, vec![column])?;
135    let execution_props = ExecutionProps::default();
136    let null_column = Column::from_name(DUMMY_COL_NAME);
137
138    let join_cols_to_replace = null_columns
139        .into_iter()
140        .map(|column| (column, &null_column))
141        .collect::<HashMap<_, _>>();
142
143    let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
144    let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
145    create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?
146        .evaluate(&input_batch)
147}
148
149fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
150    let mut expr_rewrite = TypeCoercionRewriter { schema };
151    expr.rewrite(&mut expr_rewrite).data()
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::{case, col, lit};

    /// An untyped `NULL` literal without metadata.
    fn null_lit() -> Expr {
        Expr::Literal(ScalarValue::Null, None)
    }

    #[test]
    fn expr_is_restrict_null_predicate() -> Result<()> {
        let in_list_123 = || vec![lit(1i64), lit(2i64), lit(3i64)];

        let test_cases: Vec<(Expr, bool)> = vec![
            // a
            (col("a"), true),
            // a IS NULL
            (col("a").is_null(), false),
            // a IS NOT NULL
            (col("a").is_not_null(), true),
            // a = NULL
            (col("a").eq(null_lit()), true),
            // a > 8
            (col("a").gt(lit(8i64)), true),
            // a <= 8
            (col("a").lt_eq(lit(8i32)), true),
            // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(ScalarValue::Null))?,
                true,
            ),
            // CASE a WHEN 1 THEN true ELSE false END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .otherwise(lit(false))?,
                true,
            ),
            // CASE a WHEN 0 THEN false ELSE true END
            (
                case(col("a"))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(true))?,
                false,
            ),
            // (CASE a WHEN 0 THEN false ELSE true END) OR false
            (
                case(col("a"))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(true))?
                    .or(lit(false)),
                false,
            ),
            // (CASE a WHEN 0 THEN true ELSE false END) OR false
            (
                case(col("a"))
                    .when(lit(0i64), lit(true))
                    .otherwise(lit(false))?
                    .or(lit(false)),
                true,
            ),
            // a IN (1, 2, 3)
            (col("a").in_list(in_list_123(), false), true),
            // a NOT IN (1, 2, 3)
            (col("a").in_list(in_list_123(), true), true),
            // a IN (NULL)
            (col("a").in_list(vec![null_lit()], false), true),
            // a NOT IN (NULL)
            (col("a").in_list(vec![null_lit()], true), true),
        ];

        let column_a = Column::from_name("a");
        for (predicate, expected) in test_cases {
            let actual = is_restrict_null_predicate(predicate.clone(), [&column_a])?;
            assert_eq!(actual, expected, "{predicate}");
        }

        Ok(())
    }
}