Skip to main content

datafusion_optimizer/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utility functions leveraged by the query optimizer rules
19
20use std::collections::{BTreeSet, HashMap, HashSet};
21
22use crate::analyzer::type_coercion::TypeCoercionRewriter;
23use arrow::array::{Array, RecordBatch, new_null_array};
24use arrow::datatypes::{DataType, Field, Schema};
25use datafusion_common::TableReference;
26use datafusion_common::cast::as_boolean_array;
27use datafusion_common::tree_node::{TransformedResult, TreeNode};
28use datafusion_common::{Column, DFSchema, Result, ScalarValue};
29use datafusion_expr::execution_props::ExecutionProps;
30use datafusion_expr::expr_rewriter::replace_col;
31use datafusion_expr::{ColumnarValue, Expr, logical_plan::LogicalPlan};
32use datafusion_physical_expr::create_physical_expr;
33use log::{debug, trace};
34use std::sync::Arc;
35
/// Re-export of `NamePreserver` for backwards compatibility: it was
/// initially placed in this module and later moved to
/// `datafusion_expr::expr_rewriter`.
38pub use datafusion_expr::expr_rewriter::NamePreserver;
39
40/// Returns true if `expr` contains all columns in `schema_cols`
41pub(crate) fn has_all_column_refs(
42    expr: &Expr,
43    schema_cols: &HashSet<ColumnReference>,
44) -> bool {
45    let column_refs = expr.column_refs();
46    // note can't use HashSet::intersect because of different types (owned vs References)
47    column_refs
48        .iter()
49        .filter(|c| {
50            schema_cols.contains(&ColumnReference::new(c.relation.as_ref(), c.name()))
51        })
52        .count()
53        == column_refs.len()
54}
55
56pub(crate) fn replace_qualified_name(
57    expr: Expr,
58    cols: &BTreeSet<Column>,
59    subquery_alias: &str,
60) -> Result<Expr> {
61    let alias_cols: Vec<Column> = cols
62        .iter()
63        .map(|col| Column::new(Some(subquery_alias), &col.name))
64        .collect();
65    let replace_map: HashMap<&Column, &Column> =
66        cols.iter().zip(alias_cols.iter()).collect();
67
68    replace_col(expr, &replace_map)
69}
70
71/// Column reference to avoid copying string around
72#[derive(PartialEq, Eq, Hash, Debug)]
73pub(crate) struct ColumnReference<'a> {
74    pub relation: Option<&'a TableReference>,
75    pub name: &'a str,
76}
77
78impl<'a> ColumnReference<'a> {
79    pub fn new(relation: Option<&'a TableReference>, name: &'a str) -> Self {
80        Self { relation, name }
81    }
82
83    pub fn new_unqualified(name: &'a str) -> Self {
84        Self {
85            relation: None,
86            name,
87        }
88    }
89}
90
91/// Returns references to all columns in the schema
92pub(crate) fn schema_columns<'a>(schema: &'a DFSchema) -> HashSet<ColumnReference<'a>> {
93    schema
94        .iter()
95        .flat_map(|(qualifier, field)| {
96            [
97                ColumnReference::new(qualifier, field.name()),
98                // we need to push down filter using unqualified column as well
99                ColumnReference::new_unqualified(field.name()),
100            ]
101        })
102        .collect::<HashSet<_>>()
103}
104
/// Logs `plan` after some part of the optimizer has run.
///
/// At `debug` level the plan is written using `LogicalPlan::display_indent`;
/// at `trace` level it is written again using
/// `LogicalPlan::display_indent_schema`. Both records are prefixed with
/// `description`.
pub fn log_plan(description: &str, plan: &LogicalPlan) {
    debug!("{description}:\n{}\n", plan.display_indent());
    trace!("{description}::\n{}\n", plan.display_indent_schema());
}
110
111/// Determine whether a predicate can restrict NULLs. e.g.
112/// `c0 > 8` return true;
113/// `c0 IS NULL` return false.
114pub fn is_restrict_null_predicate<'a>(
115    predicate: Expr,
116    join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
117) -> Result<bool> {
118    if matches!(predicate, Expr::Column(_)) {
119        return Ok(true);
120    }
121
122    // If result is single `true`, return false;
123    // If result is single `NULL` or `false`, return true;
124    Ok(
125        match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? {
126            ColumnarValue::Array(array) => {
127                if array.len() == 1 {
128                    let boolean_array = as_boolean_array(&array)?;
129                    boolean_array.is_null(0) || !boolean_array.value(0)
130                } else {
131                    false
132                }
133            }
134            ColumnarValue::Scalar(scalar) => matches!(
135                scalar,
136                ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
137            ),
138        },
139    )
140}
141
142/// Determines if an expression will always evaluate to null.
143/// `c0 + 8` return true
144/// `c0 IS NULL` return false
145/// `CASE WHEN c0 > 1 then 0 else 1` return false
146pub fn evaluates_to_null<'a>(
147    predicate: Expr,
148    null_columns: impl IntoIterator<Item = &'a Column>,
149) -> Result<bool> {
150    if matches!(predicate, Expr::Column(_)) {
151        return Ok(true);
152    }
153
154    Ok(
155        match evaluate_expr_with_null_column(predicate, null_columns)? {
156            ColumnarValue::Array(_) => false,
157            ColumnarValue::Scalar(scalar) => scalar.is_null(),
158        },
159    )
160}
161
162fn evaluate_expr_with_null_column<'a>(
163    predicate: Expr,
164    null_columns: impl IntoIterator<Item = &'a Column>,
165) -> Result<ColumnarValue> {
166    static DUMMY_COL_NAME: &str = "?";
167    let schema = Arc::new(Schema::new(vec![Field::new(
168        DUMMY_COL_NAME,
169        DataType::Null,
170        true,
171    )]));
172    let input_schema = DFSchema::try_from(Arc::clone(&schema))?;
173    let column = new_null_array(&DataType::Null, 1);
174    let input_batch = RecordBatch::try_new(schema, vec![column])?;
175    let execution_props = ExecutionProps::default();
176    let null_column = Column::from_name(DUMMY_COL_NAME);
177
178    let join_cols_to_replace = null_columns
179        .into_iter()
180        .map(|column| (column, &null_column))
181        .collect::<HashMap<_, _>>();
182
183    let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
184    let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
185    create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?
186        .evaluate(&input_batch)
187}
188
189fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
190    let mut expr_rewrite = TypeCoercionRewriter { schema };
191    expr.rewrite(&mut expr_rewrite).data()
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::{Operator, binary_expr, case, col, in_list, is_null, lit};

    /// Builds the untyped `NULL` literal used by several cases below.
    fn null_literal() -> Expr {
        Expr::Literal(ScalarValue::Null, None)
    }

    /// Builds `a IN (values)`, or `a NOT IN (values)` when `negated` is set.
    fn a_in_list(values: Vec<Expr>, negated: bool) -> Expr {
        in_list(col("a"), values, negated)
    }

    #[test]
    fn expr_is_restrict_null_predicate() -> Result<()> {
        let one_two_three = || vec![lit(1i64), lit(2i64), lit(3i64)];

        // Each entry pairs a predicate over column `a` with whether it is
        // expected to reject rows in which `a` is NULL.
        let cases: Vec<(Expr, bool)> = vec![
            // a
            (col("a"), true),
            // a IS NULL
            (is_null(col("a")), false),
            // a IS NOT NULL
            (Expr::IsNotNull(Box::new(col("a"))), true),
            // a = NULL
            (binary_expr(col("a"), Operator::Eq, null_literal()), true),
            // a > 8
            (binary_expr(col("a"), Operator::Gt, lit(8i64)), true),
            // a <= 8
            (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true),
            // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(ScalarValue::Null))?,
                true,
            ),
            // CASE a WHEN 1 THEN true ELSE false END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .otherwise(lit(false))?,
                true,
            ),
            // CASE a WHEN 0 THEN false ELSE true END
            (
                case(col("a"))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(true))?,
                false,
            ),
            // (CASE a WHEN 0 THEN false ELSE true END) OR false
            (
                binary_expr(
                    case(col("a"))
                        .when(lit(0i64), lit(false))
                        .otherwise(lit(true))?,
                    Operator::Or,
                    lit(false),
                ),
                false,
            ),
            // (CASE a WHEN 0 THEN true ELSE false END) OR false
            (
                binary_expr(
                    case(col("a"))
                        .when(lit(0i64), lit(true))
                        .otherwise(lit(false))?,
                    Operator::Or,
                    lit(false),
                ),
                true,
            ),
            // a IN (1, 2, 3)
            (a_in_list(one_two_three(), false), true),
            // a NOT IN (1, 2, 3)
            (a_in_list(one_two_three(), true), true),
            // a IN (NULL)
            (a_in_list(vec![null_literal()], false), true),
            // a NOT IN (NULL)
            (a_in_list(vec![null_literal()], true), true),
        ];

        let column_a = Column::from_name("a");
        for (predicate, expected) in cases {
            let actual = is_restrict_null_predicate(predicate.clone(), [&column_a])?;
            assert_eq!(actual, expected, "{predicate}");
        }

        Ok(())
    }
}