1use std::collections::{BTreeSet, HashMap, HashSet};
21
22use crate::analyzer::type_coercion::TypeCoercionRewriter;
23use arrow::array::{new_null_array, Array, RecordBatch};
24use arrow::datatypes::{DataType, Field, Schema};
25use datafusion_common::cast::as_boolean_array;
26use datafusion_common::tree_node::{TransformedResult, TreeNode};
27use datafusion_common::{Column, DFSchema, Result, ScalarValue};
28use datafusion_expr::execution_props::ExecutionProps;
29use datafusion_expr::expr_rewriter::replace_col;
30use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr};
31use datafusion_physical_expr::create_physical_expr;
32use log::{debug, trace};
33use std::sync::Arc;
34
35pub use datafusion_expr::expr_rewriter::NamePreserver;
38
39pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet<Column>) -> bool {
41 let column_refs = expr.column_refs();
42 schema_cols
44 .iter()
45 .filter(|c| column_refs.contains(c))
46 .count()
47 == column_refs.len()
48}
49
50pub(crate) fn replace_qualified_name(
51 expr: Expr,
52 cols: &BTreeSet<Column>,
53 subquery_alias: &str,
54) -> Result<Expr> {
55 let alias_cols: Vec<Column> = cols
56 .iter()
57 .map(|col| Column::new(Some(subquery_alias), &col.name))
58 .collect();
59 let replace_map: HashMap<&Column, &Column> =
60 cols.iter().zip(alias_cols.iter()).collect();
61
62 replace_col(expr, &replace_map)
63}
64
65pub fn log_plan(description: &str, plan: &LogicalPlan) {
67 debug!("{description}:\n{}\n", plan.display_indent());
68 trace!("{description}::\n{}\n", plan.display_indent_schema());
69}
70
71pub fn is_restrict_null_predicate<'a>(
75 predicate: Expr,
76 join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
77) -> Result<bool> {
78 if matches!(predicate, Expr::Column(_)) {
79 return Ok(true);
80 }
81
82 Ok(
85 match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? {
86 ColumnarValue::Array(array) => {
87 if array.len() == 1 {
88 let boolean_array = as_boolean_array(&array)?;
89 boolean_array.is_null(0) || !boolean_array.value(0)
90 } else {
91 false
92 }
93 }
94 ColumnarValue::Scalar(scalar) => matches!(
95 scalar,
96 ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
97 ),
98 },
99 )
100}
101
102pub fn evaluates_to_null<'a>(
107 predicate: Expr,
108 null_columns: impl IntoIterator<Item = &'a Column>,
109) -> Result<bool> {
110 if matches!(predicate, Expr::Column(_)) {
111 return Ok(true);
112 }
113
114 Ok(
115 match evaluate_expr_with_null_column(predicate, null_columns)? {
116 ColumnarValue::Array(_) => false,
117 ColumnarValue::Scalar(scalar) => scalar.is_null(),
118 },
119 )
120}
121
122fn evaluate_expr_with_null_column<'a>(
123 predicate: Expr,
124 null_columns: impl IntoIterator<Item = &'a Column>,
125) -> Result<ColumnarValue> {
126 static DUMMY_COL_NAME: &str = "?";
127 let schema = Arc::new(Schema::new(vec![Field::new(
128 DUMMY_COL_NAME,
129 DataType::Null,
130 true,
131 )]));
132 let input_schema = DFSchema::try_from(Arc::clone(&schema))?;
133 let column = new_null_array(&DataType::Null, 1);
134 let input_batch = RecordBatch::try_new(schema, vec![column])?;
135 let execution_props = ExecutionProps::default();
136 let null_column = Column::from_name(DUMMY_COL_NAME);
137
138 let join_cols_to_replace = null_columns
139 .into_iter()
140 .map(|column| (column, &null_column))
141 .collect::<HashMap<_, _>>();
142
143 let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
144 let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
145 create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?
146 .evaluate(&input_batch)
147}
148
149fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
150 let mut expr_rewrite = TypeCoercionRewriter { schema };
151 expr.rewrite(&mut expr_rewrite).data()
152}
153
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator};

    /// Runs `is_restrict_null_predicate` over a table of predicates on column
    /// `a`, each paired with the expected answer.
    #[test]
    fn expr_is_restrict_null_predicate() -> Result<()> {
        // Builders for sub-expressions that appear in several cases.
        let null_literal = || Expr::Literal(ScalarValue::Null, None);
        let one_two_three = || vec![lit(1i64), lit(2i64), lit(3i64)];

        let test_cases: Vec<(Expr, bool)> = vec![
            // a
            (col("a"), true),
            // a IS NULL
            (is_null(col("a")), false),
            // a IS NOT NULL
            (Expr::IsNotNull(Box::new(col("a"))), true),
            // a = NULL
            (binary_expr(col("a"), Operator::Eq, null_literal()), true),
            // a > 8
            (binary_expr(col("a"), Operator::Gt, lit(8i64)), true),
            // a <= 8
            (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true),
            // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(ScalarValue::Null))?,
                true,
            ),
            // CASE a WHEN 1 THEN true ELSE false END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .otherwise(lit(false))?,
                true,
            ),
            // CASE a WHEN 0 THEN false ELSE true END
            (
                case(col("a"))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(true))?,
                false,
            ),
            // (CASE a WHEN 0 THEN false ELSE true END) OR false
            (
                binary_expr(
                    case(col("a"))
                        .when(lit(0i64), lit(false))
                        .otherwise(lit(true))?,
                    Operator::Or,
                    lit(false),
                ),
                false,
            ),
            // (CASE a WHEN 0 THEN true ELSE false END) OR false
            (
                binary_expr(
                    case(col("a"))
                        .when(lit(0i64), lit(true))
                        .otherwise(lit(false))?,
                    Operator::Or,
                    lit(false),
                ),
                true,
            ),
            // a IN (1, 2, 3)
            (in_list(col("a"), one_two_three(), false), true),
            // a NOT IN (1, 2, 3)
            (in_list(col("a"), one_two_three(), true), true),
            // a IN (NULL)
            (in_list(col("a"), vec![null_literal()], false), true),
            // a NOT IN (NULL)
            (in_list(col("a"), vec![null_literal()], true), true),
        ];

        let column_a = Column::from_name("a");
        for (predicate, expected) in &test_cases {
            let actual = is_restrict_null_predicate(predicate.clone(), [&column_a])?;
            assert_eq!(actual, *expected, "{predicate}");
        }

        Ok(())
    }
}