1use std::collections::{BTreeSet, HashMap, HashSet};
21
22use crate::analyzer::type_coercion::TypeCoercionRewriter;
23use arrow::array::{Array, RecordBatch, new_null_array};
24use arrow::datatypes::{DataType, Field, Schema};
25use datafusion_common::TableReference;
26use datafusion_common::cast::as_boolean_array;
27use datafusion_common::tree_node::{TransformedResult, TreeNode};
28use datafusion_common::{Column, DFSchema, Result, ScalarValue};
29use datafusion_expr::execution_props::ExecutionProps;
30use datafusion_expr::expr_rewriter::replace_col;
31use datafusion_expr::{ColumnarValue, Expr, logical_plan::LogicalPlan};
32use datafusion_physical_expr::create_physical_expr;
33use log::{debug, trace};
34use std::sync::Arc;
35
36pub use datafusion_expr::expr_rewriter::NamePreserver;
39
40pub(crate) fn has_all_column_refs(
42 expr: &Expr,
43 schema_cols: &HashSet<ColumnReference>,
44) -> bool {
45 let column_refs = expr.column_refs();
46 column_refs
48 .iter()
49 .filter(|c| {
50 schema_cols.contains(&ColumnReference::new(c.relation.as_ref(), c.name()))
51 })
52 .count()
53 == column_refs.len()
54}
55
56pub(crate) fn replace_qualified_name(
57 expr: Expr,
58 cols: &BTreeSet<Column>,
59 subquery_alias: &str,
60) -> Result<Expr> {
61 let alias_cols: Vec<Column> = cols
62 .iter()
63 .map(|col| Column::new(Some(subquery_alias), &col.name))
64 .collect();
65 let replace_map: HashMap<&Column, &Column> =
66 cols.iter().zip(alias_cols.iter()).collect();
67
68 replace_col(expr, &replace_map)
69}
70
71#[derive(PartialEq, Eq, Hash, Debug)]
73pub(crate) struct ColumnReference<'a> {
74 pub relation: Option<&'a TableReference>,
75 pub name: &'a str,
76}
77
78impl<'a> ColumnReference<'a> {
79 pub fn new(relation: Option<&'a TableReference>, name: &'a str) -> Self {
80 Self { relation, name }
81 }
82
83 pub fn new_unqualified(name: &'a str) -> Self {
84 Self {
85 relation: None,
86 name,
87 }
88 }
89}
90
91pub(crate) fn schema_columns<'a>(schema: &'a DFSchema) -> HashSet<ColumnReference<'a>> {
93 schema
94 .iter()
95 .flat_map(|(qualifier, field)| {
96 [
97 ColumnReference::new(qualifier, field.name()),
98 ColumnReference::new_unqualified(field.name()),
100 ]
101 })
102 .collect::<HashSet<_>>()
103}
104
105pub fn log_plan(description: &str, plan: &LogicalPlan) {
107 debug!("{description}:\n{}\n", plan.display_indent());
108 trace!("{description}::\n{}\n", plan.display_indent_schema());
109}
110
111pub fn is_restrict_null_predicate<'a>(
115 predicate: Expr,
116 join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
117) -> Result<bool> {
118 if matches!(predicate, Expr::Column(_)) {
119 return Ok(true);
120 }
121
122 Ok(
125 match evaluate_expr_with_null_column(predicate, join_cols_of_predicate)? {
126 ColumnarValue::Array(array) => {
127 if array.len() == 1 {
128 let boolean_array = as_boolean_array(&array)?;
129 boolean_array.is_null(0) || !boolean_array.value(0)
130 } else {
131 false
132 }
133 }
134 ColumnarValue::Scalar(scalar) => matches!(
135 scalar,
136 ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
137 ),
138 },
139 )
140}
141
142pub fn evaluates_to_null<'a>(
147 predicate: Expr,
148 null_columns: impl IntoIterator<Item = &'a Column>,
149) -> Result<bool> {
150 if matches!(predicate, Expr::Column(_)) {
151 return Ok(true);
152 }
153
154 Ok(
155 match evaluate_expr_with_null_column(predicate, null_columns)? {
156 ColumnarValue::Array(_) => false,
157 ColumnarValue::Scalar(scalar) => scalar.is_null(),
158 },
159 )
160}
161
162fn evaluate_expr_with_null_column<'a>(
163 predicate: Expr,
164 null_columns: impl IntoIterator<Item = &'a Column>,
165) -> Result<ColumnarValue> {
166 static DUMMY_COL_NAME: &str = "?";
167 let schema = Arc::new(Schema::new(vec![Field::new(
168 DUMMY_COL_NAME,
169 DataType::Null,
170 true,
171 )]));
172 let input_schema = DFSchema::try_from(Arc::clone(&schema))?;
173 let column = new_null_array(&DataType::Null, 1);
174 let input_batch = RecordBatch::try_new(schema, vec![column])?;
175 let execution_props = ExecutionProps::default();
176 let null_column = Column::from_name(DUMMY_COL_NAME);
177
178 let join_cols_to_replace = null_columns
179 .into_iter()
180 .map(|column| (column, &null_column))
181 .collect::<HashMap<_, _>>();
182
183 let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
184 let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
185 create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?
186 .evaluate(&input_batch)
187}
188
189fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
190 let mut expr_rewrite = TypeCoercionRewriter { schema };
191 expr.rewrite(&mut expr_rewrite).data()
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::{Operator, binary_expr, case, col, in_list, is_null, lit};

    #[test]
    fn expr_is_restrict_null_predicate() -> Result<()> {
        // Untyped NULL literal shared by several cases.
        let null = || Expr::Literal(ScalarValue::Null, None);

        let test_cases: Vec<(Expr, bool)> = vec![
            // a
            (col("a"), true),
            // a IS NULL
            (is_null(col("a")), false),
            // a IS NOT NULL
            (Expr::IsNotNull(Box::new(col("a"))), true),
            // a = NULL
            (binary_expr(col("a"), Operator::Eq, null()), true),
            // a > 8
            (binary_expr(col("a"), Operator::Gt, lit(8i64)), true),
            // a <= 8
            (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true),
            // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(ScalarValue::Null))?,
                true,
            ),
            // CASE a WHEN 1 THEN true ELSE false END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .otherwise(lit(false))?,
                true,
            ),
            // CASE a WHEN 0 THEN false ELSE true END
            (
                case(col("a"))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(true))?,
                false,
            ),
            // (CASE a WHEN 0 THEN false ELSE true END) OR false
            (
                binary_expr(
                    case(col("a"))
                        .when(lit(0i64), lit(false))
                        .otherwise(lit(true))?,
                    Operator::Or,
                    lit(false),
                ),
                false,
            ),
            // (CASE a WHEN 0 THEN true ELSE false END) OR false
            (
                binary_expr(
                    case(col("a"))
                        .when(lit(0i64), lit(true))
                        .otherwise(lit(false))?,
                    Operator::Or,
                    lit(false),
                ),
                true,
            ),
            // a IN (1, 2, 3)
            (
                in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], false),
                true,
            ),
            // a NOT IN (1, 2, 3)
            (
                in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], true),
                true,
            ),
            // a IN (NULL)
            (in_list(col("a"), vec![null()], false), true),
            // a NOT IN (NULL)
            (in_list(col("a"), vec![null()], true), true),
        ];

        let column_a = Column::from_name("a");
        for (predicate, expected) in test_cases {
            let actual = is_restrict_null_predicate(predicate.clone(), [&column_a])?;
            assert_eq!(actual, expected, "{predicate}");
        }

        Ok(())
    }
}