datafusion_expr/logical_plan/
invariants.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use datafusion_common::{
19    internal_err, plan_err,
20    tree_node::{TreeNode, TreeNodeRecursion},
21    DFSchemaRef, Result,
22};
23
24use crate::{
25    expr::{Exists, InSubquery},
26    expr_rewriter::strip_outer_reference,
27    utils::{collect_subquery_cols, split_conjunction},
28    Aggregate, Expr, Filter, Join, JoinType, LogicalPlan, Window,
29};
30
31use super::Extension;
32
/// How strict a set of [`LogicalPlan`] invariant checks is.
///
/// Variants are declared from least to most stringent, so the derived
/// `PartialOrd` orders `Always` before `Executable`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum InvariantLevel {
    /// Holds for every DataFusion `LogicalPlan`, for example each node has
    /// the expected number of children and no output field is duplicated.
    Always,
    /// Must additionally hold for a plan to be "executable", for example
    /// function arguments have the correct number and types, and wildcards
    /// have been expanded.
    ///
    /// Run the `Analyzer` to make a `LogicalPlan` satisfy these invariants.
    Executable,
}
46
47/// Apply the [`InvariantLevel::Always`] check at the current plan node only.
48///
49/// This does not recurs to any child nodes.
50pub fn assert_always_invariants_at_current_node(plan: &LogicalPlan) -> Result<()> {
51    // Refer to <https://datafusion.apache.org/contributor-guide/specification/invariants.html#relation-name-tuples-in-logical-fields-and-logical-columns-are-unique>
52    assert_unique_field_names(plan)?;
53
54    Ok(())
55}
56
57/// Visit the plan nodes, and confirm the [`InvariantLevel::Executable`]
58/// as well as the less stringent [`InvariantLevel::Always`] checks.
59pub fn assert_executable_invariants(plan: &LogicalPlan) -> Result<()> {
60    // Always invariants
61    assert_always_invariants_at_current_node(plan)?;
62    assert_valid_extension_nodes(plan, InvariantLevel::Always)?;
63
64    // Executable invariants
65    assert_valid_extension_nodes(plan, InvariantLevel::Executable)?;
66    assert_valid_semantic_plan(plan)?;
67    Ok(())
68}
69
70/// Asserts that the query plan, and subplan, extension nodes have valid invariants.
71///
72/// Refer to [`UserDefinedLogicalNode::check_invariants`](super::UserDefinedLogicalNode)
73/// for more details of user-provided extension node invariants.
74fn assert_valid_extension_nodes(plan: &LogicalPlan, check: InvariantLevel) -> Result<()> {
75    plan.apply_with_subqueries(|plan: &LogicalPlan| {
76        if let LogicalPlan::Extension(Extension { node }) = plan {
77            node.check_invariants(check, plan)?;
78        }
79        plan.apply_expressions(|expr| {
80            // recursively look for subqueries
81            expr.apply(|expr| {
82                match expr {
83                    Expr::Exists(Exists { subquery, .. })
84                    | Expr::InSubquery(InSubquery { subquery, .. })
85                    | Expr::ScalarSubquery(subquery) => {
86                        assert_valid_extension_nodes(&subquery.subquery, check)?;
87                    }
88                    _ => {}
89                };
90                Ok(TreeNodeRecursion::Continue)
91            })
92        })
93    })
94    .map(|_| ())
95}
96
97/// Returns an error if plan, and subplans, do not have unique fields.
98///
99/// This invariant is subject to change.
100/// refer: <https://github.com/apache/datafusion/issues/13525#issuecomment-2494046463>
101fn assert_unique_field_names(plan: &LogicalPlan) -> Result<()> {
102    plan.schema().check_names()
103}
104
105/// Returns an error if the plan is not sematically valid.
106fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> {
107    assert_subqueries_are_valid(plan)?;
108
109    Ok(())
110}
111
112/// Returns an error if the plan does not have the expected schema.
113/// Ignores metadata and nullability.
114pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Result<()> {
115    let compatible = plan.schema().logically_equivalent_names_and_types(schema);
116
117    if !compatible {
118        internal_err!(
119            "Failed due to a difference in schemas: original schema: {:?}, new schema: {:?}",
120            schema,
121            plan.schema()
122        )
123    } else {
124        Ok(())
125    }
126}
127
128/// Asserts that the subqueries are structured properly with valid node placement.
129///
130/// Refer to [`check_subquery_expr`] for more details of the internal invariants.
131fn assert_subqueries_are_valid(plan: &LogicalPlan) -> Result<()> {
132    plan.apply_with_subqueries(|plan: &LogicalPlan| {
133        plan.apply_expressions(|expr| {
134            // recursively look for subqueries
135            expr.apply(|expr| {
136                match expr {
137                    Expr::Exists(Exists { subquery, .. })
138                    | Expr::InSubquery(InSubquery { subquery, .. })
139                    | Expr::ScalarSubquery(subquery) => {
140                        check_subquery_expr(plan, &subquery.subquery, expr)?;
141                    }
142                    _ => {}
143                };
144                Ok(TreeNodeRecursion::Continue)
145            })
146        })
147    })
148    .map(|_| ())
149}
150
151/// Do necessary check on subquery expressions and fail the invalid plan
152/// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions,
153///    the allowed while list: [Projection, Filter, Window, Aggregate, Join].
154/// 2) Check whether the inner plan is in the allowed inner plans list to use correlated(outer) expressions.
155/// 3) Check and validate unsupported cases to use the correlated(outer) expressions inside the subquery(inner) plans/inner expressions.
156///    For example, we do not want to support to use correlated expressions as the Join conditions in the subquery plan when the Join
157///    is a Full Out Join
158pub fn check_subquery_expr(
159    outer_plan: &LogicalPlan,
160    inner_plan: &LogicalPlan,
161    expr: &Expr,
162) -> Result<()> {
163    assert_subqueries_are_valid(inner_plan)?;
164    if let Expr::ScalarSubquery(subquery) = expr {
165        // Scalar subquery should only return one column
166        if subquery.subquery.schema().fields().len() > 1 {
167            return plan_err!(
168                "Scalar subquery should only return one column, but found {}: {}",
169                subquery.subquery.schema().fields().len(),
170                subquery.subquery.schema().field_names().join(", ")
171            );
172        }
173        // Correlated scalar subquery must be aggregated to return at most one row
174        if !subquery.outer_ref_columns.is_empty() {
175            match strip_inner_query(inner_plan) {
176                LogicalPlan::Aggregate(agg) => {
177                    check_aggregation_in_scalar_subquery(inner_plan, agg)
178                }
179                LogicalPlan::Filter(Filter { input, .. })
180                    if matches!(input.as_ref(), LogicalPlan::Aggregate(_)) =>
181                {
182                    if let LogicalPlan::Aggregate(agg) = input.as_ref() {
183                        check_aggregation_in_scalar_subquery(inner_plan, agg)
184                    } else {
185                        Ok(())
186                    }
187                }
188                _ => {
189                    if inner_plan
190                        .max_rows()
191                        .filter(|max_row| *max_row <= 1)
192                        .is_some()
193                    {
194                        Ok(())
195                    } else {
196                        plan_err!(
197                            "Correlated scalar subquery must be aggregated to return at most one row"
198                        )
199                    }
200                }
201            }?;
202            match outer_plan {
203                LogicalPlan::Projection(_)
204                | LogicalPlan::Filter(_) => Ok(()),
205                LogicalPlan::Aggregate(Aggregate { group_expr, aggr_expr, .. }) => {
206                    if group_expr.contains(expr) && !aggr_expr.contains(expr) {
207                        // TODO revisit this validation logic
208                        plan_err!(
209                            "Correlated scalar subquery in the GROUP BY clause must also be in the aggregate expressions"
210                        )
211                    } else {
212                        Ok(())
213                    }
214                }
215                _ => plan_err!(
216                    "Correlated scalar subquery can only be used in Projection, Filter, Aggregate plan nodes"
217                )
218            }?;
219        }
220        check_correlations_in_subquery(inner_plan)
221    } else {
222        if let Expr::InSubquery(subquery) = expr {
223            // InSubquery should only return one column
224            if subquery.subquery.subquery.schema().fields().len() > 1 {
225                return plan_err!(
226                    "InSubquery should only return one column, but found {}: {}",
227                    subquery.subquery.subquery.schema().fields().len(),
228                    subquery.subquery.subquery.schema().field_names().join(", ")
229                );
230            }
231        }
232        match outer_plan {
233            LogicalPlan::Projection(_)
234            | LogicalPlan::Filter(_)
235            | LogicalPlan::TableScan(_)
236            | LogicalPlan::Window(_)
237            | LogicalPlan::Aggregate(_)
238            | LogicalPlan::Join(_) => Ok(()),
239            _ => plan_err!(
240                "In/Exist subquery can only be used in \
241                Projection, Filter, TableScan, Window functions, Aggregate and Join plan nodes, \
242                but was used in [{}]",
243                outer_plan.display()
244            ),
245        }?;
246        check_correlations_in_subquery(inner_plan)
247    }
248}
249
250// Recursively check the unsupported outer references in the sub query plan.
251fn check_correlations_in_subquery(inner_plan: &LogicalPlan) -> Result<()> {
252    check_inner_plan(inner_plan)
253}
254
255// Recursively check the unsupported outer references in the sub query plan.
256#[cfg_attr(feature = "recursive_protection", recursive::recursive)]
257fn check_inner_plan(inner_plan: &LogicalPlan) -> Result<()> {
258    // We want to support as many operators as possible inside the correlated subquery
259    match inner_plan {
260        LogicalPlan::Aggregate(_) => {
261            inner_plan.apply_children(|plan| {
262                check_inner_plan(plan)?;
263                Ok(TreeNodeRecursion::Continue)
264            })?;
265            Ok(())
266        }
267        LogicalPlan::Filter(Filter { input, .. }) => check_inner_plan(input),
268        LogicalPlan::Window(window) => {
269            check_mixed_out_refer_in_window(window)?;
270            inner_plan.apply_children(|plan| {
271                check_inner_plan(plan)?;
272                Ok(TreeNodeRecursion::Continue)
273            })?;
274            Ok(())
275        }
276        LogicalPlan::Projection(_)
277        | LogicalPlan::Distinct(_)
278        | LogicalPlan::Sort(_)
279        | LogicalPlan::Union(_)
280        | LogicalPlan::TableScan(_)
281        | LogicalPlan::EmptyRelation(_)
282        | LogicalPlan::Limit(_)
283        | LogicalPlan::Values(_)
284        | LogicalPlan::Subquery(_)
285        | LogicalPlan::SubqueryAlias(_)
286        | LogicalPlan::Unnest(_) => {
287            inner_plan.apply_children(|plan| {
288                check_inner_plan(plan)?;
289                Ok(TreeNodeRecursion::Continue)
290            })?;
291            Ok(())
292        }
293        LogicalPlan::Join(Join {
294            left,
295            right,
296            join_type,
297            ..
298        }) => match join_type {
299            JoinType::Inner => {
300                inner_plan.apply_children(|plan| {
301                    check_inner_plan(plan)?;
302                    Ok(TreeNodeRecursion::Continue)
303                })?;
304                Ok(())
305            }
306            JoinType::Left
307            | JoinType::LeftSemi
308            | JoinType::LeftAnti
309            | JoinType::LeftMark => {
310                check_inner_plan(left)?;
311                check_no_outer_references(right)
312            }
313            JoinType::Right
314            | JoinType::RightSemi
315            | JoinType::RightAnti
316            | JoinType::RightMark => {
317                check_no_outer_references(left)?;
318                check_inner_plan(right)
319            }
320            JoinType::Full => {
321                inner_plan.apply_children(|plan| {
322                    check_no_outer_references(plan)?;
323                    Ok(TreeNodeRecursion::Continue)
324                })?;
325                Ok(())
326            }
327        },
328        LogicalPlan::Extension(_) => Ok(()),
329        plan => check_no_outer_references(plan),
330    }
331}
332
333fn check_no_outer_references(inner_plan: &LogicalPlan) -> Result<()> {
334    if inner_plan.contains_outer_reference() {
335        plan_err!(
336            "Accessing outer reference columns is not allowed in the plan: {}",
337            inner_plan.display()
338        )
339    } else {
340        Ok(())
341    }
342}
343
344fn check_aggregation_in_scalar_subquery(
345    inner_plan: &LogicalPlan,
346    agg: &Aggregate,
347) -> Result<()> {
348    if agg.aggr_expr.is_empty() {
349        return plan_err!(
350            "Correlated scalar subquery must be aggregated to return at most one row"
351        );
352    }
353    if !agg.group_expr.is_empty() {
354        let correlated_exprs = get_correlated_expressions(inner_plan)?;
355        let inner_subquery_cols =
356            collect_subquery_cols(&correlated_exprs, agg.input.schema())?;
357        let mut group_columns = agg
358            .group_expr
359            .iter()
360            .map(|group| Ok(group.column_refs().into_iter().cloned().collect::<Vec<_>>()))
361            .collect::<Result<Vec<_>>>()?
362            .into_iter()
363            .flatten();
364
365        if !group_columns.all(|group| inner_subquery_cols.contains(&group)) {
366            // Group BY columns must be a subset of columns in the correlated expressions
367            return plan_err!(
368                "A GROUP BY clause in a scalar correlated subquery cannot contain non-correlated columns"
369            );
370        }
371    }
372    Ok(())
373}
374
375fn strip_inner_query(inner_plan: &LogicalPlan) -> &LogicalPlan {
376    match inner_plan {
377        LogicalPlan::Projection(projection) => {
378            strip_inner_query(projection.input.as_ref())
379        }
380        LogicalPlan::SubqueryAlias(alias) => strip_inner_query(alias.input.as_ref()),
381        other => other,
382    }
383}
384
385fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result<Vec<Expr>> {
386    let mut exprs = vec![];
387    inner_plan.apply_with_subqueries(|plan| {
388        if let LogicalPlan::Filter(Filter { predicate, .. }) = plan {
389            let (correlated, _): (Vec<_>, Vec<_>) = split_conjunction(predicate)
390                .into_iter()
391                .partition(|e| e.contains_outer());
392
393            for expr in correlated {
394                exprs.push(strip_outer_reference(expr.clone()));
395            }
396        }
397        Ok(TreeNodeRecursion::Continue)
398    })?;
399    Ok(exprs)
400}
401
402/// Check whether the window expressions contain a mixture of out reference columns and inner columns
403fn check_mixed_out_refer_in_window(window: &Window) -> Result<()> {
404    let mixed = window
405        .window_expr
406        .iter()
407        .any(|win_expr| win_expr.contains_outer() && win_expr.any_column_refs());
408    if mixed {
409        plan_err!(
410            "Window expressions should not contain a mixed of outer references and inner columns"
411        )
412    } else {
413        Ok(())
414    }
415}
416
#[cfg(test)]
mod test {
    use std::cmp::Ordering;
    use std::sync::Arc;

    use crate::{Extension, UserDefinedLogicalNodeCore};
    use datafusion_common::{DFSchema, DFSchemaRef};

    use super::*;

    /// Minimal leaf extension node with an empty schema. It is used to check
    /// how the invariant checks treat `LogicalPlan::Extension`.
    #[derive(Debug, PartialEq, Eq, Hash)]
    struct MockUserDefinedLogicalPlan {
        empty_schema: DFSchemaRef,
    }

    impl PartialOrd for MockUserDefinedLogicalPlan {
        fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
            // `PartialOrd` must agree with the derived `PartialEq`: equal
            // values compare as `Equal`. Distinct mocks have no meaningful
            // order, so they stay incomparable.
            if self == other {
                Some(Ordering::Equal)
            } else {
                None
            }
        }
    }

    impl UserDefinedLogicalNodeCore for MockUserDefinedLogicalPlan {
        fn name(&self) -> &str {
            "MockUserDefinedLogicalPlan"
        }

        fn inputs(&self) -> Vec<&LogicalPlan> {
            vec![]
        }

        fn schema(&self) -> &DFSchemaRef {
            &self.empty_schema
        }

        fn expressions(&self) -> Vec<Expr> {
            vec![]
        }

        fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            write!(f, "MockUserDefinedLogicalPlan")
        }

        fn with_exprs_and_inputs(
            &self,
            _exprs: Vec<Expr>,
            _inputs: Vec<LogicalPlan>,
        ) -> Result<Self> {
            Ok(Self {
                empty_schema: Arc::clone(&self.empty_schema),
            })
        }

        fn supports_limit_pushdown(&self) -> bool {
            false // Disallow limit push-down by default
        }
    }

    /// Extension nodes are accepted by the correlated-subquery walk.
    #[test]
    fn wont_fail_extension_plan() {
        let plan = LogicalPlan::Extension(Extension {
            node: Arc::new(MockUserDefinedLogicalPlan {
                empty_schema: DFSchemaRef::new(DFSchema::empty()),
            }),
        });

        check_inner_plan(&plan).unwrap();
    }

    /// `Always` is the less stringent level and must order before `Executable`.
    #[test]
    fn invariant_levels_are_ordered_by_strictness() {
        assert!(InvariantLevel::Always < InvariantLevel::Executable);
    }
}