1use std::collections::{BTreeSet, HashMap, HashSet};
21
22use crate::{OptimizerConfig, OptimizerRule};
23
24use crate::analyzer::type_coercion::TypeCoercionRewriter;
25use arrow::array::{new_null_array, Array, RecordBatch};
26use arrow::datatypes::{DataType, Field, Schema};
27use datafusion_common::cast::as_boolean_array;
28use datafusion_common::tree_node::{TransformedResult, TreeNode};
29use datafusion_common::{Column, DFSchema, Result, ScalarValue};
30use datafusion_expr::execution_props::ExecutionProps;
31use datafusion_expr::expr_rewriter::replace_col;
32use datafusion_expr::{logical_plan::LogicalPlan, ColumnarValue, Expr};
33use datafusion_physical_expr::create_physical_expr;
34use log::{debug, trace};
35use std::sync::Arc;
36
37pub use datafusion_expr::expr_rewriter::NamePreserver;
40
41#[deprecated(
49 since = "40.0.0",
50 note = "please use OptimizerRule::apply_order with ApplyOrder::BottomUp instead"
51)]
52pub fn optimize_children(
53 optimizer: &impl OptimizerRule,
54 plan: &LogicalPlan,
55 config: &dyn OptimizerConfig,
56) -> Result<Option<LogicalPlan>> {
57 let mut new_inputs = Vec::with_capacity(plan.inputs().len());
58 let mut plan_is_changed = false;
59 for input in plan.inputs() {
60 if optimizer.supports_rewrite() {
61 let new_input = optimizer.rewrite(input.clone(), config)?;
62 plan_is_changed = plan_is_changed || new_input.transformed;
63 new_inputs.push(new_input.data);
64 } else {
65 #[allow(deprecated)]
66 let new_input = optimizer.try_optimize(input, config)?;
67 plan_is_changed = plan_is_changed || new_input.is_some();
68 new_inputs.push(new_input.unwrap_or_else(|| input.clone()))
69 }
70 }
71 if plan_is_changed {
72 let exprs = plan.expressions();
73 plan.with_new_exprs(exprs, new_inputs).map(Some)
74 } else {
75 Ok(None)
76 }
77}
78
79pub(crate) fn has_all_column_refs(expr: &Expr, schema_cols: &HashSet<Column>) -> bool {
81 let column_refs = expr.column_refs();
82 schema_cols
84 .iter()
85 .filter(|c| column_refs.contains(c))
86 .count()
87 == column_refs.len()
88}
89
90pub(crate) fn replace_qualified_name(
91 expr: Expr,
92 cols: &BTreeSet<Column>,
93 subquery_alias: &str,
94) -> Result<Expr> {
95 let alias_cols: Vec<Column> = cols
96 .iter()
97 .map(|col| Column::new(Some(subquery_alias), &col.name))
98 .collect();
99 let replace_map: HashMap<&Column, &Column> =
100 cols.iter().zip(alias_cols.iter()).collect();
101
102 replace_col(expr, &replace_map)
103}
104
105pub fn log_plan(description: &str, plan: &LogicalPlan) {
107 debug!("{description}:\n{}\n", plan.display_indent());
108 trace!("{description}::\n{}\n", plan.display_indent_schema());
109}
110
111pub fn is_restrict_null_predicate<'a>(
115 predicate: Expr,
116 join_cols_of_predicate: impl IntoIterator<Item = &'a Column>,
117) -> Result<bool> {
118 if matches!(predicate, Expr::Column(_)) {
119 return Ok(true);
120 }
121
122 static DUMMY_COL_NAME: &str = "?";
123 let schema = Schema::new(vec![Field::new(DUMMY_COL_NAME, DataType::Null, true)]);
124 let input_schema = DFSchema::try_from(schema.clone())?;
125 let column = new_null_array(&DataType::Null, 1);
126 let input_batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![column])?;
127 let execution_props = ExecutionProps::default();
128 let null_column = Column::from_name(DUMMY_COL_NAME);
129
130 let join_cols_to_replace = join_cols_of_predicate
131 .into_iter()
132 .map(|column| (column, &null_column))
133 .collect::<HashMap<_, _>>();
134
135 let replaced_predicate = replace_col(predicate, &join_cols_to_replace)?;
136 let coerced_predicate = coerce(replaced_predicate, &input_schema)?;
137 let phys_expr =
138 create_physical_expr(&coerced_predicate, &input_schema, &execution_props)?;
139
140 let result_type = phys_expr.data_type(&schema)?;
141 if !matches!(&result_type, DataType::Boolean) {
142 return Ok(false);
143 }
144
145 Ok(match phys_expr.evaluate(&input_batch)? {
148 ColumnarValue::Array(array) => {
149 if array.len() == 1 {
150 let boolean_array = as_boolean_array(&array)?;
151 boolean_array.is_null(0) || !boolean_array.value(0)
152 } else {
153 false
154 }
155 }
156 ColumnarValue::Scalar(scalar) => matches!(
157 scalar,
158 ScalarValue::Boolean(None) | ScalarValue::Boolean(Some(false))
159 ),
160 })
161}
162
163fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
164 let mut expr_rewrite = TypeCoercionRewriter { schema };
165 expr.rewrite(&mut expr_rewrite).data()
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::{binary_expr, case, col, in_list, is_null, lit, Operator};

    /// Checks `is_restrict_null_predicate` against a table of predicates
    /// over column `a`, where `a` is the only join column.
    #[test]
    fn expr_is_restrict_null_predicate() -> Result<()> {
        let cases: Vec<(Expr, bool)> = vec![
            // a
            (col("a"), true),
            // a IS NULL
            (is_null(col("a")), false),
            // a IS NOT NULL
            (Expr::IsNotNull(Box::new(col("a"))), true),
            // a = NULL
            (
                binary_expr(col("a"), Operator::Eq, Expr::Literal(ScalarValue::Null)),
                true,
            ),
            // a > 8
            (binary_expr(col("a"), Operator::Gt, lit(8i64)), true),
            // a <= 8
            (binary_expr(col("a"), Operator::LtEq, lit(8i32)), true),
            // CASE a WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(ScalarValue::Null))?,
                true,
            ),
            // CASE a WHEN 1 THEN true ELSE false END
            (
                case(col("a"))
                    .when(lit(1i64), lit(true))
                    .otherwise(lit(false))?,
                true,
            ),
            // CASE a WHEN 0 THEN false ELSE true END
            (
                case(col("a"))
                    .when(lit(0i64), lit(false))
                    .otherwise(lit(true))?,
                false,
            ),
            // (CASE a WHEN 0 THEN false ELSE true END) OR false
            (
                binary_expr(
                    case(col("a"))
                        .when(lit(0i64), lit(false))
                        .otherwise(lit(true))?,
                    Operator::Or,
                    lit(false),
                ),
                false,
            ),
            // (CASE a WHEN 0 THEN true ELSE false END) OR false
            (
                binary_expr(
                    case(col("a"))
                        .when(lit(0i64), lit(true))
                        .otherwise(lit(false))?,
                    Operator::Or,
                    lit(false),
                ),
                true,
            ),
            // a IN (1, 2, 3)
            (
                in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], false),
                true,
            ),
            // a NOT IN (1, 2, 3)
            (
                in_list(col("a"), vec![lit(1i64), lit(2i64), lit(3i64)], true),
                true,
            ),
            // a IN (NULL)
            (
                in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], false),
                true,
            ),
            // a NOT IN (NULL)
            (
                in_list(col("a"), vec![Expr::Literal(ScalarValue::Null)], true),
                true,
            ),
        ];

        let join_col = Column::from_name("a");
        for (predicate, expected) in &cases {
            let actual = is_restrict_null_predicate(
                predicate.clone(),
                std::iter::once(&join_col),
            )?;
            assert_eq!(actual, *expected, "{}", predicate);
        }

        Ok(())
    }
}