datafusion_expr/logical_plan/
invariants.rs1use datafusion_common::{
19 DFSchemaRef, Result, assert_or_internal_err, plan_err,
20 tree_node::{TreeNode, TreeNodeRecursion},
21};
22
23use crate::{
24 Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
25 expr::{Exists, InSubquery},
26 expr_rewriter::strip_outer_reference,
27 utils::{collect_subquery_cols, split_conjunction},
28};
29
30use super::Extension;
31
/// Selects which set of invariants a plan, or an extension node inside it,
/// is checked against.
///
/// The value is forwarded unchanged to an extension node's `check_invariants`.
/// The derived ordering follows declaration order, so `Always` compares as
/// less than `Executable`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum InvariantLevel {
    /// First level that `assert_executable_invariants` checks extension
    /// nodes against.
    Always,
    /// Level that `assert_executable_invariants` checks extension nodes
    /// against after `Always`.
    Executable,
}
45
46pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
50 assert_unique_field_names(plan)?;
52
53 Ok(())
54}
55
56pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
59 assert_always_invariants_at_current_node(plan)?;
61 assert_valid_extension_nodes(plan, InvariantLevel::Always)?;
62
63 assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
65 assert_valid_semantic_plan(plan)?;
66 Ok(())
67}
68
69fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
74 plan.apply_with_subqueries(|plan: &LogicalPlan| {
75 if let LogicalPlan::Extension(Extension { node }) = plan {
76 node.check_invariants(check)?;
77 }
78 plan.apply_expressions(|expr| {
79 expr.apply(|expr| {
81 match expr {
82 Expr::Exists(Exists { subquery, .. })
83 | Expr::InSubquery(InSubquery { subquery, .. })
84 | Expr::ScalarSubquery(subquery) => {
85 assert_valid_extension_nodes(&subquery.subquery, check)?;
86 }
87 _ => {}
88 };
89 Ok(TreeNodeRecursion::Continue)
90 })
91 })
92 })
93 .map(|_| ())
94}
95
96fn assert_unique_field_names(plan: &LogicalPlan) -> Result<()> {
101 plan.schema().check_names()
102}
103
104fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> {
106 assert_subqueries_are_valid(plan)?;
107
108 Ok(())
109}
110
111pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Result<()> {
114 let compatible = plan.schema().logically_equivalent_names_and_types(schema);
115
116 assert_or_internal_err!(
117 compatible,
118 "Failed due to a difference in schemas: original schema: {:?}, new schema: {:?}",
119 schema,
120 plan.schema()
121 );
122 Ok(())
123}
124
125fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
129 plan.apply_with_subqueries(|plan: &LogicalPlan| {
130 plan.apply_expressions(|expr| {
131 expr.apply(|expr| {
133 match expr {
134 Expr::Exists(Exists { subquery, .. })
135 | Expr::InSubquery(InSubquery { subquery, .. })
136 | Expr::ScalarSubquery(subquery) => {
137 check_subquery_expr(plan, &subquery.subquery, expr)?;
138 }
139 _ => {}
140 };
141 Ok(TreeNodeRecursion::Continue)
142 })
143 })
144 })
145 .map(|_| ())
146}
147
148pub fn check_subquery_expr(
156 outer_plan: &LogicalPlan,
157 inner_plan: &LogicalPlan,
158 expr: &Expr,
159) -> Result<()> {
160 assert_subqueries_are_valid(inner_plan)?;
161 if let Expr::ScalarSubquery(subquery) = expr {
162 if subquery.subquery.schema().fields().len() > 1 {
164 return plan_err!(
165 "Scalar subquery should only return one column, but found {}: {}",
166 subquery.subquery.schema().fields().len(),
167 subquery.subquery.schema().field_names().join(", ")
168 );
169 }
170 if !subquery.outer_ref_columns.is_empty() {
172 match strip_inner_query(inner_plan) {
173 LogicalPlan::Aggregate(agg) => {
174 check_aggregation_in_scalar_subquery(inner_plan, agg)
175 }
176 LogicalPlan::Filter(Filter { input, .. })
177 if matches!(input.as_ref(), LogicalPlan::Aggregate(_)) =>
178 {
179 if let LogicalPlan::Aggregate(agg) = input.as_ref() {
180 check_aggregation_in_scalar_subquery(inner_plan, agg)
181 } else {
182 Ok(())
183 }
184 }
185 _ => {
186 if inner_plan
187 .max_rows()
188 .filter(|max_row| *max_row <= 1)
189 .is_some()
190 {
191 Ok(())
192 } else {
193 plan_err!(
194 "Correlated scalar subquery must be aggregated to return at most one row"
195 )
196 }
197 }
198 }?;
199 match outer_plan {
200 LogicalPlan::Projection(_) | LogicalPlan::Filter(_) => Ok(()),
201 LogicalPlan::Aggregate(Aggregate {
202 group_expr,
203 aggr_expr,
204 ..
205 }) => {
206 if group_expr.contains(expr) && !aggr_expr.contains(expr) {
207 plan_err!(
209 "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions"
210 )
211 } else {
212 Ok(())
213 }
214 }
215 _ => plan_err!(
216 "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes"
217 ),
218 }?;
219 }
220 check_correlations_in_subquery(inner_plan)
221 } else {
222 if let Expr::InSubquery(subquery) = expr {
223 if subquery.subquery.subquery.schema().fields().len() > 1 {
225 return plan_err!(
226 "InSubquery should only return one column, but found {}: {}",
227 subquery.subquery.subquery.schema().fields().len(),
228 subquery.subquery.subquery.schema().field_names().join(", ")
229 );
230 }
231 }
232 match outer_plan {
233 LogicalPlan::Projection(_)
234 | LogicalPlan::Filter(_)
235 | LogicalPlan::TableScan(_)
236 | LogicalPlan::Window(_)
237 | LogicalPlan::Aggregate(_)
238 | LogicalPlan::Join(_) => Ok(()),
239 _ => plan_err!(
240 "In/Exist subquery can only be used in \
241 Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \
242 but was used in [{}]",
243 outer_plan.display()
244 ),
245 }?;
246 check_correlations_in_subquery(inner_plan)
247 }
248}
249
250fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
252 check_inner_plan(inner_plan)
253}
254
255#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
257fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> {
258 match inner_plan {
260 LogicalPlan::Aggregate(_) => {
261 inner_plan.apply_children(|plan| {
262 check_inner_plan(plan)?;
263 Ok(TreeNodeRecursion::Continue)
264 })?;
265 Ok(())
266 }
267 LogicalPlan::Filter(Filter { input, .. }) => check_inner_plan(input),
268 LogicalPlan::Window(window) => {
269 check_mixed_out_refer_in_window(window)?;
270 inner_plan.apply_children(|plan| {
271 check_inner_plan(plan)?;
272 Ok(TreeNodeRecursion::Continue)
273 })?;
274 Ok(())
275 }
276 LogicalPlan::Projection(_)
277 | LogicalPlan::Distinct(_)
278 | LogicalPlan::Sort(_)
279 | LogicalPlan::Union(_)
280 | LogicalPlan::TableScan(_)
281 | LogicalPlan::EmptyRelation(_)
282 | LogicalPlan::Limit(_)
283 | LogicalPlan::Values(_)
284 | LogicalPlan::Subquery(_)
285 | LogicalPlan::SubqueryAlias(_)
286 | LogicalPlan::Unnest(_) => {
287 inner_plan.apply_children(|plan| {
288 check_inner_plan(plan)?;
289 Ok(TreeNodeRecursion::Continue)
290 })?;
291 Ok(())
292 }
293 LogicalPlan::Join(Join {
294 left,
295 right,
296 join_type,
297 ..
298 }) => match join_type {
299 JoinType::Inner => {
300 inner_plan.apply_children(|plan| {
301 check_inner_plan(plan)?;
302 Ok(TreeNodeRecursion::Continue)
303 })?;
304 Ok(())
305 }
306 JoinType::Left
307 | JoinType::LeftSemi
308 | JoinType::LeftAnti
309 | JoinType::LeftMark => {
310 check_inner_plan(left)?;
311 check_no_outer_references(right)
312 }
313 JoinType::Right
314 | JoinType::RightSemi
315 | JoinType::RightAnti
316 | JoinType::RightMark => {
317 check_no_outer_references(left)?;
318 check_inner_plan(right)
319 }
320 JoinType::Full => {
321 inner_plan.apply_children(|plan| {
322 check_no_outer_references(plan)?;
323 Ok(TreeNodeRecursion::Continue)
324 })?;
325 Ok(())
326 }
327 },
328 LogicalPlan::Extension(_) => Ok(()),
329 plan => check_no_outer_references(plan),
330 }
331}
332
333fn check_no_outer_references(inner_plan: &LogicalPlan) -> Result<()> {
334 if inner_plan.contains_outer_reference() {
335 plan_err!(
336 "Accessing outer reference columns is not allowed in the plan: {}",
337 inner_plan.display()
338 )
339 } else {
340 Ok(())
341 }
342}
343
344fn check_aggregation_in_scalar_subquery(
345 inner_plan: &LogicalPlan,
346 agg: &Aggregate,
347) -> Result<()> {
348 if agg.aggr_expr.is_empty() {
349 return plan_err!(
350 "Correlated scalar subquery must be aggregated to return at most one row"
351 );
352 }
353 if !agg.group_expr.is_empty() {
354 let correlated_exprs = get_correlated_expressions(inner_plan)?;
355 let inner_subquery_cols =
356 collect_subquery_cols(&correlated_exprs, agg.input.schema())?;
357 let mut group_columns = agg
358 .group_expr
359 .iter()
360 .map(|group| Ok(group.column_refs().into_iter().cloned().collect::<Vec<_>>()))
361 .collect::<Result<Vec<_>>>()?
362 .into_iter()
363 .flatten();
364
365 if !group_columns.all(|group| inner_subquery_cols.contains(&group)) {
366 return plan_err!(
368 "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns"
369 );
370 }
371 }
372 Ok(())
373}
374
375fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
376 match inner_plan {
377 LogicalPlan::Projection(projection) => {
378 strip_inner_query(projection.input.as_ref())
379 }
380 LogicalPlan::SubqueryAlias(alias) => strip_inner_query(alias.input.as_ref()),
381 other => other,
382 }
383}
384
385fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
386 let mut exprs = vec![];
387 inner_plan.apply_with_subqueries(|plan| {
388 if let LogicalPlan::Filter(Filter { predicate, .. }) = plan {
389 let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate)
390 .into_iter()
391 .partition(|e| e.contains_outer());
392
393 for expr in correlated {
394 exprs.push(strip_outer_reference(expr.clone()));
395 }
396 }
397 Ok(TreeNodeRecursion::Continue)
398 })?;
399 Ok(exprs)
400}
401
402fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
404 let mixed = window
405 .window_expr
406 .iter()
407 .any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs());
408 if mixed {
409 plan_err!(
410 "Window expressions should not contain a mixed of outer references and inner columns"
411 )
412 } else {
413 Ok(())
414 }
415}
416
#[cfg(test)]
mod test {
    use std::cmp::Ordering;
    use std::sync::Arc;

    use crate::{Extension, UserDefinedLogicalNodeCore};
    use datafusion_common::{DFSchema, DFSchemaRef};

    use super::*;

    /// Minimal user-defined node: no inputs, no expressions, and a schema
    /// supplied by the test.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockUserDefinedLogicalPlan {
        empty_schema: DFSchemaRef,
    }

    impl PartialOrd for MockUserDefinedLogicalPlan {
        /// Instances of the mock are never ordered relative to each other.
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl UserDefinedLogicalNodeCore for MockUserDefinedLogicalPlan {
        fn name(&self) -> &str {
            "MockUserDefinedLogicalPlan"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            Vec::new()
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.empty_schema
        }

        fn expressions(&self) -> Vec<Expr> {
            Vec::new()
        }

        fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.write_str("MockUserDefinedLogicalPlan")
        }

        fn with_exprs_and_inputs(
            &self,
            _exprs: Vec<Expr>,
            _inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            // The mock carries no state besides its schema, so rebuilding it
            // just shares the same schema handle.
            let empty_schema = Arc::clone(&self.empty_schema);
            Ok(Self { empty_schema })
        }

        fn supports_limit_pushdown(&self) -> bool {
            false
        }
    }

    /// `check_inner_plan` must accept an extension node.
    #[test]
    fn wont_fail_extension_plan() {
        let node = Arc::new(MockUserDefinedLogicalPlan {
            empty_schema: DFSchemaRef::new(DFSchema::empty()),
        });
        let plan = LogicalPlan::Extension(Extension { node });

        check_inner_plan(&plan).unwrap();
    }
}