Skip to main content

datafusion_expr/logical_plan/
invariants.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion_common::{
19    DFSchemaRef, Result, assert_or_internal_err, plan_err,
20    tree_node::{TreeNode, TreeNodeRecursion},
21};
22
23use crate::{
24    Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
25    expr::{Exists, InSubquery, SetComparison},
26    expr_rewriter::strip_outer_reference,
27    utils::{collect_subquery_cols, split_conjunction},
28};
29
30use super::Extension;
31
/// How strict a set of `LogicalPlan` invariant checks is.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, PartialOrd)]
pub enum InvariantLevel {
    /// Invariants that hold for every DataFusion `LogicalPlan`, for example
    /// that each node has the expected number of children and that output
    /// fields are not duplicated.
    Always,
    /// Invariants that must additionally hold before a plan is "executable",
    /// for example that function arguments have the correct type and count
    /// and that wildcards have been expanded.
    ///
    /// Run the `Analyzer` to make a `LogicalPlan` satisfy the `Executable`
    /// invariants.
    Executable,
}
45
46/// Apply the [`InvariantLevel::Always`] check at the current plan node only.
47///
48/// This does not recurs to any child nodes.
49pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
50    // Refer to <https://datafusion.apache.org/contributor-guide/specification/invariants.html#relation-name-tuples-in-logical-fields-and-logical-columns-are-unique>
51    assert_unique_field_names(plan)?;
52
53    Ok(())
54}
55
56/// Visit the plan nodes, and confirm the [`InvariantLevel::Executable`]
57/// as well as the less stringent [`InvariantLevel::Always`] checks.
58pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
59    // Always invariants
60    assert_always_invariants_at_current_node(plan)?;
61    assert_valid_extension_nodes(plan, InvariantLevel::Always)?;
62
63    // Executable invariants
64    assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
65    assert_valid_semantic_plan(plan)?;
66    Ok(())
67}
68
69/// Asserts that the query plan, and subplan, extension nodes have valid invariants.
70///
71/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode)
72/// for more details of user-provided extension node invariants.
73fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
74    plan.apply_with_subqueries(|plan: &LogicalPlan| {
75        if let LogicalPlan::Extension(Extension { node }) = plan {
76            node.check_invariants(check)?;
77        }
78        plan.apply_expressions(|expr| {
79            // recursively look for subqueries
80            expr.apply(|expr| {
81                match expr {
82                    Expr::Exists(Exists { subquery, .. })
83                    | Expr::InSubquery(InSubquery { subquery, .. })
84                    | Expr::SetComparison(SetComparison { subquery, .. })
85                    | Expr::ScalarSubquery(subquery) => {
86                        assert_valid_extension_nodes(&subquery.subquery, check)?;
87                    }
88                    _ => {}
89                };
90                Ok(TreeNodeRecursion::Continue)
91            })
92        })
93    })
94    .map(|_| ())
95}
96
97/// Returns an error if plan, and subplans, do not have unique fields.
98///
99/// This invariant is subject to change.
100/// refer: <https://github.com/apache/datafusion/issues/13525#issuecomment-2494046463>
101fn assert_unique_field_names(plan: &LogicalPlan) -> Result<()> {
102    plan.schema().check_names()
103}
104
105/// Returns an error if the plan is not semantically valid.
106fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> {
107    assert_subqueries_are_valid(plan)?;
108
109    Ok(())
110}
111
112/// Returns an error if the plan does not have the expected schema.
113/// Ignores metadata and nullability.
114pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Result<()> {
115    let compatible = plan.schema().logically_equivalent_names_and_types(schema);
116
117    assert_or_internal_err!(
118        compatible,
119        "Failed due to a difference in schemas: original schema: {:?}, new schema: {:?}",
120        schema,
121        plan.schema()
122    );
123    Ok(())
124}
125
126/// Asserts that the subqueries are structured properly with valid node placement.
127///
128/// Refer to [`check_subquery_expr`] for more details of the internal invariants.
129fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
130    plan.apply_with_subqueries(|plan: &LogicalPlan| {
131        plan.apply_expressions(|expr| {
132            // recursively look for subqueries
133            expr.apply(|expr| {
134                match expr {
135                    Expr::Exists(Exists { subquery, .. })
136                    | Expr::InSubquery(InSubquery { subquery, .. })
137                    | Expr::SetComparison(SetComparison { subquery, .. })
138                    | Expr::ScalarSubquery(subquery) => {
139                        check_subquery_expr(plan, &subquery.subquery, expr)?;
140                    }
141                    _ => {}
142                };
143                Ok(TreeNodeRecursion::Continue)
144            })
145        })
146    })
147    .map(|_| ())
148}
149
150/// Do necessary check on subquery expressions and fail the invalid plan
151/// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions,
152///    the allowed while list: [Projection, Filter, Window, Aggregate, Join].
153/// 2) Check whether the inner plan is in the allowed inner plans list to use correlated(outer) expressions.
154/// 3) Check and validate unsupported cases to use the correlated(outer) expressions inside the subquery(inner) plans/inner expressions.
155///    For example, we do not want to support to use correlated expressions as the Join conditions in the subquery plan when the Join
156///    is a Full Out Join
157pub fn check_subquery_expr(
158    outer_plan: &LogicalPlan,
159    inner_plan: &LogicalPlan,
160    expr: &Expr,
161) -> Result<()> {
162    assert_subqueries_are_valid(inner_plan)?;
163    if let Expr::ScalarSubquery(subquery) = expr {
164        // Scalar subquery should only return one column
165        if subquery.subquery.schema().fields().len() > 1 {
166            return plan_err!(
167                "Scalar subquery should only return one column, but found {}: {}",
168                subquery.subquery.schema().fields().len(),
169                subquery.subquery.schema().field_names().join(", ")
170            );
171        }
172        // Correlated scalar subquery must be aggregated to return at most one row
173        if !subquery.outer_ref_columns.is_empty() {
174            match strip_inner_query(inner_plan) {
175                LogicalPlan::Aggregate(agg) => {
176                    check_aggregation_in_scalar_subquery(inner_plan, agg)
177                }
178                LogicalPlan::Filter(Filter { input, .. })
179                    if matches!(input.as_ref(), LogicalPlan::Aggregate(_)) =>
180                {
181                    if let LogicalPlan::Aggregate(agg) = input.as_ref() {
182                        check_aggregation_in_scalar_subquery(inner_plan, agg)
183                    } else {
184                        Ok(())
185                    }
186                }
187                _ => {
188                    if inner_plan
189                        .max_rows()
190                        .filter(|max_row| *max_row <= 1)
191                        .is_some()
192                    {
193                        Ok(())
194                    } else {
195                        plan_err!(
196                            "Correlated scalar subquery must be aggregated to return at most one row"
197                        )
198                    }
199                }
200            }?;
201            match outer_plan {
202                LogicalPlan::Projection(_) | LogicalPlan::Filter(_) => Ok(()),
203                LogicalPlan::Aggregate(Aggregate {
204                    group_expr,
205                    aggr_expr,
206                    ..
207                }) => {
208                    if group_expr.contains(expr) && !aggr_expr.contains(expr) {
209                        // TODO revisit this validation logic
210                        plan_err!(
211                            "Correlated scalar subquery in the GROUP BY clause must \
212                            also be in the aggregate expressions"
213                        )
214                    } else {
215                        Ok(())
216                    }
217                }
218                _ => plan_err!(
219                    "Correlated scalar subquery can only be used in Projection, \
220                    Filter, Aggregate plan nodes"
221                ),
222            }?;
223        }
224        check_correlations_in_subquery(inner_plan)
225    } else {
226        if let Expr::InSubquery(subquery) = expr {
227            // InSubquery should only return one column
228            if subquery.subquery.subquery.schema().fields().len() > 1 {
229                return plan_err!(
230                    "InSubquery should only return one column, but found {}: {}",
231                    subquery.subquery.subquery.schema().fields().len(),
232                    subquery.subquery.subquery.schema().field_names().join(", ")
233                );
234            }
235        }
236        if let Expr::SetComparison(set_comparison) = expr
237            && set_comparison.subquery.subquery.schema().fields().len() > 1
238        {
239            return plan_err!(
240                "Set comparison subquery should only return one column, but found {}: {}",
241                set_comparison.subquery.subquery.schema().fields().len(),
242                set_comparison
243                    .subquery
244                    .subquery
245                    .schema()
246                    .field_names()
247                    .join(", ")
248            );
249        }
250        match outer_plan {
251            LogicalPlan::Projection(_)
252            | LogicalPlan::Filter(_)
253            | LogicalPlan::TableScan(_)
254            | LogicalPlan::Window(_)
255            | LogicalPlan::Aggregate(_)
256            | LogicalPlan::Join(_) => Ok(()),
257            _ => plan_err!(
258                "In/Exist/SetComparison subquery can only be used in \
259                Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \
260                but was used in [{}]",
261                outer_plan.display()
262            ),
263        }?;
264        check_correlations_in_subquery(inner_plan)
265    }
266}
267
268// Recursively check the unsupported outer references in the sub query plan.
269fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
270    check_inner_plan(inner_plan)
271}
272
273// Recursively check the unsupported outer references in the sub query plan.
274#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
275fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> {
276    // We want to support as many operators as possible inside the correlated subquery
277    match inner_plan {
278        LogicalPlan::Aggregate(_) => {
279            inner_plan.apply_children(|plan| {
280                check_inner_plan(plan)?;
281                Ok(TreeNodeRecursion::Continue)
282            })?;
283            Ok(())
284        }
285        LogicalPlan::Filter(Filter { input, .. }) => check_inner_plan(input),
286        LogicalPlan::Window(window) => {
287            check_mixed_out_refer_in_window(window)?;
288            inner_plan.apply_children(|plan| {
289                check_inner_plan(plan)?;
290                Ok(TreeNodeRecursion::Continue)
291            })?;
292            Ok(())
293        }
294        LogicalPlan::Projection(_)
295        | LogicalPlan::Distinct(_)
296        | LogicalPlan::Sort(_)
297        | LogicalPlan::Union(_)
298        | LogicalPlan::TableScan(_)
299        | LogicalPlan::EmptyRelation(_)
300        | LogicalPlan::Limit(_)
301        | LogicalPlan::Values(_)
302        | LogicalPlan::Subquery(_)
303        | LogicalPlan::SubqueryAlias(_)
304        | LogicalPlan::Unnest(_) => {
305            inner_plan.apply_children(|plan| {
306                check_inner_plan(plan)?;
307                Ok(TreeNodeRecursion::Continue)
308            })?;
309            Ok(())
310        }
311        LogicalPlan::Join(Join {
312            left,
313            right,
314            join_type,
315            ..
316        }) => match join_type {
317            JoinType::Inner => {
318                inner_plan.apply_children(|plan| {
319                    check_inner_plan(plan)?;
320                    Ok(TreeNodeRecursion::Continue)
321                })?;
322                Ok(())
323            }
324            JoinType::Left
325            | JoinType::LeftSemi
326            | JoinType::LeftAnti
327            | JoinType::LeftMark => {
328                check_inner_plan(left)?;
329                check_no_outer_references(right)
330            }
331            JoinType::Right
332            | JoinType::RightSemi
333            | JoinType::RightAnti
334            | JoinType::RightMark => {
335                check_no_outer_references(left)?;
336                check_inner_plan(right)
337            }
338            JoinType::Full => {
339                inner_plan.apply_children(|plan| {
340                    check_no_outer_references(plan)?;
341                    Ok(TreeNodeRecursion::Continue)
342                })?;
343                Ok(())
344            }
345        },
346        LogicalPlan::Extension(_) => Ok(()),
347        plan => check_no_outer_references(plan),
348    }
349}
350
351fn check_no_outer_references(inner_plan: &LogicalPlan) -> Result<()> {
352    if inner_plan.contains_outer_reference() {
353        plan_err!(
354            "Accessing outer reference columns is not allowed in the plan: {}",
355            inner_plan.display()
356        )
357    } else {
358        Ok(())
359    }
360}
361
362fn check_aggregation_in_scalar_subquery(
363    inner_plan: &LogicalPlan,
364    agg: &Aggregate,
365) -> Result<()> {
366    if agg.aggr_expr.is_empty() {
367        return plan_err!(
368            "Correlated scalar subquery must be aggregated to return at most one row"
369        );
370    }
371    if !agg.group_expr.is_empty() {
372        let correlated_exprs = get_correlated_expressions(inner_plan)?;
373        let inner_subquery_cols =
374            collect_subquery_cols(&correlated_exprs, agg.input.schema())?;
375        let mut group_columns = agg
376            .group_expr
377            .iter()
378            .map(|group| Ok(group.column_refs().into_iter().cloned().collect::<Vec<_>>()))
379            .collect::<Result<Vec<_>>>()?
380            .into_iter()
381            .flatten();
382
383        if !group_columns.all(|group| inner_subquery_cols.contains(&group)) {
384            // Group BY columns must be a subset of columns in the correlated expressions
385            return plan_err!(
386                "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns"
387            );
388        }
389    }
390    Ok(())
391}
392
393fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
394    match inner_plan {
395        LogicalPlan::Projection(projection) => {
396            strip_inner_query(projection.input.as_ref())
397        }
398        LogicalPlan::SubqueryAlias(alias) => strip_inner_query(alias.input.as_ref()),
399        other => other,
400    }
401}
402
403fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
404    let mut exprs = vec![];
405    inner_plan.apply_with_subqueries(|plan| {
406        if let LogicalPlan::Filter(Filter { predicate, .. }) = plan {
407            let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate)
408                .into_iter()
409                .partition(|e| e.contains_outer());
410
411            for expr in correlated {
412                exprs.push(strip_outer_reference(expr.clone()));
413            }
414        }
415        Ok(TreeNodeRecursion::Continue)
416    })?;
417    Ok(exprs)
418}
419
420/// Check whether the window expressions contain a mixture of out reference columns and inner columns
421fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
422    let mixed = window
423        .window_expr
424        .iter()
425        .any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs());
426    if mixed {
427        plan_err!(
428            "Window expressions should not contain a mixed of outer references and inner columns"
429        )
430    } else {
431        Ok(())
432    }
433}
434
#[cfg(test)]
mod test {
    use std::cmp::Ordering;
    use std::sync::Arc;

    use crate::{Extension, UserDefinedLogicalNodeCore};
    use datafusion_common::{DFSchema, DFSchemaRef};

    use super::*;

    /// Minimal user-defined node with no inputs, no expressions and an empty
    /// schema, used to exercise the extension-node branch of
    /// `check_inner_plan`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockUserDefinedLogicalPlan {
        empty_schema: DFSchemaRef,
    }

    impl PartialOrd for MockUserDefinedLogicalPlan {
        /// Mock nodes are never ordered relative to each other.
        fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
            None
        }
    }

    impl UserDefinedLogicalNodeCore for MockUserDefinedLogicalPlan {
        fn name(&self) -> &str {
            "MockUserDefinedLogicalPlan"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            Vec::new()
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.empty_schema
        }

        fn expressions(&self) -> Vec<Expr> {
            Vec::new()
        }

        fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.write_str("MockUserDefinedLogicalPlan")
        }

        fn with_exprs_and_inputs(
            &self,
            _exprs: Vec<Expr>,
            _inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            let empty_schema = Arc::clone(&self.empty_schema);
            Ok(Self { empty_schema })
        }

        fn supports_limit_pushdown(&self) -> bool {
            false // Disallow limit push-down by default
        }
    }

    /// An extension node must be accepted by `check_inner_plan`.
    #[test]
    fn wont_fail_extension_plan() {
        let node = MockUserDefinedLogicalPlan {
            empty_schema: Arc::new(DFSchema::empty()),
        };
        let plan = LogicalPlan::Extension(Extension {
            node: Arc::new(node),
        });

        assert!(check_inner_plan(&plan).is_ok());
    }
}