datafusion_expr/logical_plan/
invariants.rs1use datafusion_common::{
19 DFSchemaRef, Result, assert_or_internal_err, plan_err,
20 tree_node::{TreeNode, TreeNodeRecursion},
21};
22
23use crate::{
24 Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
25 expr::{Exists, InSubquery, SetComparison},
26 expr_rewriter::strip_outer_reference,
27 utils::{collect_subquery_cols, split_conjunction},
28};
29
30use super::Extension;
31
/// Selects which set of invariants a check should enforce.
///
/// The level is handed to extension nodes through `check_invariants`.
/// [`assert_executable_invariants`] runs the extension-node checks at both
/// levels, `Always` first and `Executable` second.
///
/// The derived ordering follows declaration order, so
/// `InvariantLevel::Always < InvariantLevel::Executable`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
pub enum InvariantLevel {
    /// Level checked first by [`assert_executable_invariants`].
    // NOTE(review): presumably "must hold for every plan" — confirm.
    Always,
    /// Additional level checked by [`assert_executable_invariants`] after
    /// the `Always` checks have passed.
    Executable,
}
45
46pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
50 assert_unique_field_names(plan)?;
52
53 Ok(())
54}
55
56pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
59 assert_always_invariants_at_current_node(plan)?;
61 assert_valid_extension_nodes(plan, InvariantLevel::Always)?;
62
63 assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
65 assert_valid_semantic_plan(plan)?;
66 Ok(())
67}
68
69fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
74 plan.apply_with_subqueries(|plan: &LogicalPlan| {
75 if let LogicalPlan::Extension(Extension { node }) = plan {
76 node.check_invariants(check)?;
77 }
78 plan.apply_expressions(|expr| {
79 expr.apply(|expr| {
81 match expr {
82 Expr::Exists(Exists { subquery, .. })
83 | Expr::InSubquery(InSubquery { subquery, .. })
84 | Expr::SetComparison(SetComparison { subquery, .. })
85 | Expr::ScalarSubquery(subquery) => {
86 assert_valid_extension_nodes(&subquery.subquery, check)?;
87 }
88 _ => {}
89 };
90 Ok(TreeNodeRecursion::Continue)
91 })
92 })
93 })
94 .map(|_| ())
95}
96
97fn assert_unique_field_names(plan: &LogicalPlan) -> Result<()> {
102 plan.schema().check_names()
103}
104
105fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> {
107 assert_subqueries_are_valid(plan)?;
108
109 Ok(())
110}
111
112pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Result<()> {
115 let compatible = plan.schema().logically_equivalent_names_and_types(schema);
116
117 assert_or_internal_err!(
118 compatible,
119 "Failed due to a difference in schemas: original schema: {:?}, new schema: {:?}",
120 schema,
121 plan.schema()
122 );
123 Ok(())
124}
125
126fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
130 plan.apply_with_subqueries(|plan: &LogicalPlan| {
131 plan.apply_expressions(|expr| {
132 expr.apply(|expr| {
134 match expr {
135 Expr::Exists(Exists { subquery, .. })
136 | Expr::InSubquery(InSubquery { subquery, .. })
137 | Expr::SetComparison(SetComparison { subquery, .. })
138 | Expr::ScalarSubquery(subquery) => {
139 check_subquery_expr(plan, &subquery.subquery, expr)?;
140 }
141 _ => {}
142 };
143 Ok(TreeNodeRecursion::Continue)
144 })
145 })
146 })
147 .map(|_| ())
148}
149
150pub fn check_subquery_expr(
158 outer_plan: &LogicalPlan,
159 inner_plan: &LogicalPlan,
160 expr: &Expr,
161) -> Result<()> {
162 assert_subqueries_are_valid(inner_plan)?;
163 if let Expr::ScalarSubquery(subquery) = expr {
164 if subquery.subquery.schema().fields().len() > 1 {
166 return plan_err!(
167 "Scalar subquery should only return one column, but found {}: {}",
168 subquery.subquery.schema().fields().len(),
169 subquery.subquery.schema().field_names().join(", ")
170 );
171 }
172 if !subquery.outer_ref_columns.is_empty() {
174 match strip_inner_query(inner_plan) {
175 LogicalPlan::Aggregate(agg) => {
176 check_aggregation_in_scalar_subquery(inner_plan, agg)
177 }
178 LogicalPlan::Filter(Filter { input, .. })
179 if matches!(input.as_ref(), LogicalPlan::Aggregate(_)) =>
180 {
181 if let LogicalPlan::Aggregate(agg) = input.as_ref() {
182 check_aggregation_in_scalar_subquery(inner_plan, agg)
183 } else {
184 Ok(())
185 }
186 }
187 _ => {
188 if inner_plan
189 .max_rows()
190 .filter(|max_row| *max_row <= 1)
191 .is_some()
192 {
193 Ok(())
194 } else {
195 plan_err!(
196 "Correlated scalar subquery must be aggregated to return at most one row"
197 )
198 }
199 }
200 }?;
201 match outer_plan {
202 LogicalPlan::Projection(_) | LogicalPlan::Filter(_) => Ok(()),
203 LogicalPlan::Aggregate(Aggregate {
204 group_expr,
205 aggr_expr,
206 ..
207 }) => {
208 if group_expr.contains(expr) && !aggr_expr.contains(expr) {
209 plan_err!(
211 "Correlated scalar subquery in the GROUP BY clause must \
212 also be in the aggregate expressions"
213 )
214 } else {
215 Ok(())
216 }
217 }
218 _ => plan_err!(
219 "Correlated scalar subquery can only be used in Projection, \
220 Filter, Aggregate plan nodes"
221 ),
222 }?;
223 }
224 check_correlations_in_subquery(inner_plan)
225 } else {
226 if let Expr::InSubquery(subquery) = expr {
227 if subquery.subquery.subquery.schema().fields().len() > 1 {
229 return plan_err!(
230 "InSubquery should only return one column, but found {}: {}",
231 subquery.subquery.subquery.schema().fields().len(),
232 subquery.subquery.subquery.schema().field_names().join(", ")
233 );
234 }
235 }
236 if let Expr::SetComparison(set_comparison) = expr
237 && set_comparison.subquery.subquery.schema().fields().len() > 1
238 {
239 return plan_err!(
240 "Set comparison subquery should only return one column, but found {}: {}",
241 set_comparison.subquery.subquery.schema().fields().len(),
242 set_comparison
243 .subquery
244 .subquery
245 .schema()
246 .field_names()
247 .join(", ")
248 );
249 }
250 match outer_plan {
251 LogicalPlan::Projection(_)
252 | LogicalPlan::Filter(_)
253 | LogicalPlan::TableScan(_)
254 | LogicalPlan::Window(_)
255 | LogicalPlan::Aggregate(_)
256 | LogicalPlan::Join(_) => Ok(()),
257 _ => plan_err!(
258 "In/Exist/SetComparison subquery can only be used in \
259 Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \
260 but was used in [{}]",
261 outer_plan.display()
262 ),
263 }?;
264 check_correlations_in_subquery(inner_plan)
265 }
266}
267
268fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
270 check_inner_plan(inner_plan)
271}
272
273#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
275fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> {
276 match inner_plan {
278 LogicalPlan::Aggregate(_) => {
279 inner_plan.apply_children(|plan| {
280 check_inner_plan(plan)?;
281 Ok(TreeNodeRecursion::Continue)
282 })?;
283 Ok(())
284 }
285 LogicalPlan::Filter(Filter { input, .. }) => check_inner_plan(input),
286 LogicalPlan::Window(window) => {
287 check_mixed_out_refer_in_window(window)?;
288 inner_plan.apply_children(|plan| {
289 check_inner_plan(plan)?;
290 Ok(TreeNodeRecursion::Continue)
291 })?;
292 Ok(())
293 }
294 LogicalPlan::Projection(_)
295 | LogicalPlan::Distinct(_)
296 | LogicalPlan::Sort(_)
297 | LogicalPlan::Union(_)
298 | LogicalPlan::TableScan(_)
299 | LogicalPlan::EmptyRelation(_)
300 | LogicalPlan::Limit(_)
301 | LogicalPlan::Values(_)
302 | LogicalPlan::Subquery(_)
303 | LogicalPlan::SubqueryAlias(_)
304 | LogicalPlan::Unnest(_) => {
305 inner_plan.apply_children(|plan| {
306 check_inner_plan(plan)?;
307 Ok(TreeNodeRecursion::Continue)
308 })?;
309 Ok(())
310 }
311 LogicalPlan::Join(Join {
312 left,
313 right,
314 join_type,
315 ..
316 }) => match join_type {
317 JoinType::Inner => {
318 inner_plan.apply_children(|plan| {
319 check_inner_plan(plan)?;
320 Ok(TreeNodeRecursion::Continue)
321 })?;
322 Ok(())
323 }
324 JoinType::Left
325 | JoinType::LeftSemi
326 | JoinType::LeftAnti
327 | JoinType::LeftMark => {
328 check_inner_plan(left)?;
329 check_no_outer_references(right)
330 }
331 JoinType::Right
332 | JoinType::RightSemi
333 | JoinType::RightAnti
334 | JoinType::RightMark => {
335 check_no_outer_references(left)?;
336 check_inner_plan(right)
337 }
338 JoinType::Full => {
339 inner_plan.apply_children(|plan| {
340 check_no_outer_references(plan)?;
341 Ok(TreeNodeRecursion::Continue)
342 })?;
343 Ok(())
344 }
345 },
346 LogicalPlan::Extension(_) => Ok(()),
347 plan => check_no_outer_references(plan),
348 }
349}
350
351fn check_no_outer_references(inner_plan: &LogicalPlan) -> Result<()> {
352 if inner_plan.contains_outer_reference() {
353 plan_err!(
354 "Accessing outer reference columns is not allowed in the plan: {}",
355 inner_plan.display()
356 )
357 } else {
358 Ok(())
359 }
360}
361
362fn check_aggregation_in_scalar_subquery(
363 inner_plan: &LogicalPlan,
364 agg: &Aggregate,
365) -> Result<()> {
366 if agg.aggr_expr.is_empty() {
367 return plan_err!(
368 "Correlated scalar subquery must be aggregated to return at most one row"
369 );
370 }
371 if !agg.group_expr.is_empty() {
372 let correlated_exprs = get_correlated_expressions(inner_plan)?;
373 let inner_subquery_cols =
374 collect_subquery_cols(&correlated_exprs, agg.input.schema())?;
375 let mut group_columns = agg
376 .group_expr
377 .iter()
378 .map(|group| Ok(group.column_refs().into_iter().cloned().collect::<Vec<_>>()))
379 .collect::<Result<Vec<_>>>()?
380 .into_iter()
381 .flatten();
382
383 if !group_columns.all(|group| inner_subquery_cols.contains(&group)) {
384 return plan_err!(
386 "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns"
387 );
388 }
389 }
390 Ok(())
391}
392
393fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
394 match inner_plan {
395 LogicalPlan::Projection(projection) => {
396 strip_inner_query(projection.input.as_ref())
397 }
398 LogicalPlan::SubqueryAlias(alias) => strip_inner_query(alias.input.as_ref()),
399 other => other,
400 }
401}
402
403fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
404 let mut exprs = vec![];
405 inner_plan.apply_with_subqueries(|plan| {
406 if let LogicalPlan::Filter(Filter { predicate, .. }) = plan {
407 let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate)
408 .into_iter()
409 .partition(|e| e.contains_outer());
410
411 for expr in correlated {
412 exprs.push(strip_outer_reference(expr.clone()));
413 }
414 }
415 Ok(TreeNodeRecursion::Continue)
416 })?;
417 Ok(exprs)
418}
419
420fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
422 let mixed = window
423 .window_expr
424 .iter()
425 .any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs());
426 if mixed {
427 plan_err!(
428 "Window expressions should not contain a mixed of outer references and inner columns"
429 )
430 } else {
431 Ok(())
432 }
433}
434
#[cfg(test)]
mod test {
    use std::cmp::Ordering;
    use std::sync::Arc;

    use crate::{Extension, UserDefinedLogicalNodeCore};
    use datafusion_common::{DFSchema, DFSchemaRef};

    use super::*;

    /// Minimal user-defined node: no inputs, no expressions, and whatever
    /// schema it was constructed with.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockUserDefinedLogicalPlan {
        empty_schema: DFSchemaRef,
    }

    impl PartialOrd for MockUserDefinedLogicalPlan {
        /// Mock nodes are never ordered relative to one another.
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl UserDefinedLogicalNodeCore for MockUserDefinedLogicalPlan {
        fn name(&self) -> &str {
            "MockUserDefinedLogicalPlan"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            Vec::new()
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.empty_schema
        }

        fn expressions(&self) -> Vec<Expr> {
            Vec::new()
        }

        fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.write_str("MockUserDefinedLogicalPlan")
        }

        fn with_exprs_and_inputs(
            &self,
            _exprs: Vec<Expr>,
            _inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            let empty_schema = Arc::clone(&self.empty_schema);
            Ok(Self { empty_schema })
        }

        fn supports_limit_pushdown(&self) -> bool {
            false
        }
    }

    /// `check_inner_plan` accepts an extension node as-is.
    #[test]
    fn wont_fail_extension_plan() {
        let empty_schema: DFSchemaRef = Arc::new(DFSchema::empty());
        let plan = LogicalPlan::Extension(Extension {
            node: Arc::new(MockUserDefinedLogicalPlan { empty_schema }),
        });

        check_inner_plan(&plan).unwrap();
    }
}