// datafusion/sql/src/utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! SQL Utility Functions
19
20use std::vec;
21
22use arrow::datatypes::{
23    DECIMAL_DEFAULT_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType,
24};
25use datafusion_common::tree_node::{
26    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
27};
28use datafusion_common::{
29    Column, DFSchemaRef, Diagnostic, HashMap, Result, ScalarValue,
30    assert_or_internal_err, exec_datafusion_err, exec_err, internal_err, plan_err,
31};
32use datafusion_expr::builder::get_struct_unnested_columns;
33use datafusion_expr::expr::{
34    Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
35};
36use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
37use datafusion_expr::{
38    ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, col, expr_vec_fmt,
39};
40
41use indexmap::IndexMap;
42use sqlparser::ast::{Ident, Value};
43
44/// Make a best-effort attempt at resolving all columns in the expression tree
45pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
46    expr.clone()
47        .transform_up(|nested_expr| {
48            match nested_expr {
49                Expr::Column(col) => {
50                    let (qualifier, field) =
51                        plan.schema().qualified_field_from_column(&col)?;
52                    Ok(Transformed::yes(Expr::Column(Column::from((
53                        qualifier, field,
54                    )))))
55                }
56                _ => {
57                    // keep recursing
58                    Ok(Transformed::no(nested_expr))
59                }
60            }
61        })
62        .data()
63}
64
65/// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
66///
67/// For example, the expression `a + b < 1` would require, as input, the 2
68/// individual columns, `a` and `b`. But, if the base expressions already
69/// contain the `a + b` result, then that may be used in lieu of the `a` and
70/// `b` columns.
71///
72/// This is useful in the context of a query like:
73///
74/// SELECT a + b < 1 ... GROUP BY a + b
75///
76/// where post-aggregation, `a + b` need not be a projection against the
77/// individual columns `a` and `b`, but rather it is a projection against the
78/// `a + b` found in the GROUP BY.
79pub(crate) fn rebase_expr(
80    expr: &Expr,
81    base_exprs: &[Expr],
82    plan: &LogicalPlan,
83) -> Result<Expr> {
84    expr.clone()
85        .transform_down(|nested_expr| {
86            if base_exprs.contains(&nested_expr) {
87                Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
88            } else {
89                Ok(Transformed::no(nested_expr))
90            }
91        })
92        .data()
93}
94
/// Identifies which SQL clause a column reference came from when validating
/// that every column in an aggregate query appears in GROUP BY or inside an
/// aggregate function. Used to select a clause-specific error message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
    /// Column appeared in the SELECT projection list
    Projection,
    /// Column appeared in the HAVING clause
    Having,
    /// Column appeared in the QUALIFY clause
    Qualify,
    /// Column appeared in the ORDER BY clause
    OrderBy,
}
102
/// The kind of validation performed by [`check_columns_satisfy_exprs`].
/// Currently only aggregate-related checks exist, parameterized by the SQL
/// clause the checked expressions came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
    Aggregate(CheckColumnsMustReferenceAggregatePurpose),
}
107
impl CheckColumnsSatisfyExprsPurpose {
    /// Clause-specific prefix prepended to the planning error message.
    fn message_prefix(&self) -> &'static str {
        match self {
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
                "Column in SELECT must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
                "Column in HAVING must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
                "Column in QUALIFY must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::OrderBy) => {
                "Column in ORDER BY must be in GROUP BY or an aggregate function"
            }
        }
    }

    /// Message attached to the `Diagnostic` built for the offending
    /// expression. Deliberately the same for every purpose today; the
    /// clause-specific wording lives in [`Self::message_prefix`].
    fn diagnostic_message(&self, expr: &Expr) -> String {
        format!(
            "'{expr}' must appear in GROUP BY clause because it's not an aggregate expression"
        )
    }
}
132
133/// Determines if the set of `Expr`'s are a valid projection on the input
134/// `Expr::Column`'s.
135pub(crate) fn check_columns_satisfy_exprs(
136    columns: &[Expr],
137    exprs: &[Expr],
138    purpose: CheckColumnsSatisfyExprsPurpose,
139) -> Result<()> {
140    columns.iter().try_for_each(|c| match c {
141        Expr::Column(_) => Ok(()),
142        _ => internal_err!("Expr::Column are required"),
143    })?;
144    let column_exprs = find_column_exprs(exprs);
145    for e in &column_exprs {
146        match e {
147            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
148                for e in exprs {
149                    check_column_satisfies_expr(columns, e, purpose)?;
150                }
151            }
152            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
153                for e in exprs {
154                    check_column_satisfies_expr(columns, e, purpose)?;
155                }
156            }
157            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
158                for exprs in lists_of_exprs {
159                    for e in exprs {
160                        check_column_satisfies_expr(columns, e, purpose)?;
161                    }
162                }
163            }
164            _ => check_column_satisfies_expr(columns, e, purpose)?,
165        }
166    }
167    Ok(())
168}
169
/// Verify that `expr` (a column reference found inside a SELECT / HAVING /
/// QUALIFY / ORDER BY expression) is one of the permitted `columns`
/// (typically the GROUP BY expressions plus aggregate outputs).
///
/// Returns a planning error, enriched with a span-aware [`Diagnostic`],
/// when the column is not covered.
fn check_column_satisfies_expr(
    columns: &[Expr],
    expr: &Expr,
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    if !columns.contains(expr) {
        // Build a diagnostic pointing at the offending expression's source
        // span (when available) with an actionable suggestion.
        let diagnostic = Diagnostic::new_error(
            purpose.diagnostic_message(expr),
            expr.spans().and_then(|spans| spans.first()),
        )
        .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);

        return plan_err!(
            "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
            purpose.message_prefix(),
            expr,
            expr_vec_fmt!(columns);
            diagnostic=diagnostic
        );
    }
    Ok(())
}
192
193/// Returns mapping of each alias (`String`) to the expression (`Expr`) it is
194/// aliasing.
195pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
196    exprs
197        .iter()
198        .filter_map(|expr| match expr {
199            Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
200            _ => None,
201        })
202        .collect::<HashMap<String, Expr>>()
203}
204
205/// Given an expression that's literal int encoding position, lookup the corresponding expression
206/// in the select_exprs list, if the index is within the bounds and it is indeed a position literal,
207/// otherwise, returns planning error.
208/// If input expression is not an int literal, returns expression as-is.
209pub(crate) fn resolve_positions_to_exprs(
210    expr: Expr,
211    select_exprs: &[Expr],
212) -> Result<Expr> {
213    match expr {
214        // sql_expr_to_logical_expr maps number to i64
215        // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887
216        Expr::Literal(ScalarValue::Int64(Some(position)), _)
217            if position > 0_i64 && position <= select_exprs.len() as i64 =>
218        {
219            let index = (position - 1) as usize;
220            let select_expr = &select_exprs[index];
221            Ok(match select_expr {
222                Expr::Alias(Alias { expr, .. }) => *expr.clone(),
223                _ => select_expr.clone(),
224            })
225        }
226        Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!(
227            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
228            position,
229            select_exprs.len()
230        ),
231        _ => Ok(expr),
232    }
233}
234
235/// Rebuilds an `Expr` with columns that refer to aliases replaced by the
236/// alias' underlying `Expr`.
237pub(crate) fn resolve_aliases_to_exprs(
238    expr: Expr,
239    aliases: &HashMap<String, Expr>,
240) -> Result<Expr> {
241    expr.transform_up(|nested_expr| match nested_expr {
242        Expr::Column(c) if c.relation.is_none() => {
243            if let Some(aliased_expr) = aliases.get(&c.name) {
244                Ok(Transformed::yes(aliased_expr.clone()))
245            } else {
246                Ok(Transformed::no(Expr::Column(c)))
247            }
248        }
249        _ => Ok(Transformed::no(nested_expr)),
250    })
251    .data()
252}
253
254/// Given a slice of window expressions sharing the same sort key, find their common partition
255/// keys.
256pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
257    let all_partition_keys = window_exprs
258        .iter()
259        .map(|expr| match expr {
260            Expr::WindowFunction(window_fun) => {
261                let WindowFunction {
262                    params: WindowFunctionParams { partition_by, .. },
263                    ..
264                } = window_fun.as_ref();
265                Ok(partition_by)
266            }
267            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
268                Expr::WindowFunction(window_fun) => {
269                    let WindowFunction {
270                        params: WindowFunctionParams { partition_by, .. },
271                        ..
272                    } = window_fun.as_ref();
273                    Ok(partition_by)
274                }
275                expr => exec_err!("Impossibly got non-window expr {expr:?}"),
276            },
277            expr => exec_err!("Impossibly got non-window expr {expr:?}"),
278        })
279        .collect::<Result<Vec<_>>>()?;
280    let result = all_partition_keys
281        .iter()
282        .min_by_key(|s| s.len())
283        .ok_or_else(|| exec_datafusion_err!("No window expressions found"))?;
284    Ok(result)
285}
286
287/// Returns a validated `DataType` for the specified precision and
288/// scale
289pub(crate) fn make_decimal_type(
290    precision: Option<u64>,
291    scale: Option<u64>,
292) -> Result<DataType> {
293    // postgres like behavior
294    let (precision, scale) = match (precision, scale) {
295        (Some(p), Some(s)) => (p as u8, s as i8),
296        (Some(p), None) => (p as u8, 0),
297        (None, Some(_)) => {
298            return plan_err!("Cannot specify only scale for decimal data type");
299        }
300        (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
301    };
302
303    if precision == 0
304        || precision > DECIMAL256_MAX_PRECISION
305        || scale.unsigned_abs() > precision
306    {
307        plan_err!(
308            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
309        )
310    } else if precision > DECIMAL128_MAX_PRECISION
311        && precision <= DECIMAL256_MAX_PRECISION
312    {
313        Ok(DataType::Decimal256(precision, scale))
314    } else {
315        Ok(DataType::Decimal128(precision, scale))
316    }
317}
318
319/// Normalize an owned identifier to a lowercase string, unless the identifier is quoted.
320pub(crate) fn normalize_ident(id: Ident) -> String {
321    match id.quote_style {
322        Some(_) => id.value,
323        None => id.value.to_ascii_lowercase(),
324    }
325}
326
/// Convert a SQL literal [`Value`] into its textual contents, when the
/// literal kind has a well-defined plain-text value.
///
/// Returns `None` for literal kinds (byte strings, raw strings, hex
/// literals, NULL, placeholders, ...) that this crate does not interpret
/// as plain text.
pub(crate) fn value_to_string(value: &Value) -> Option<String> {
    match value {
        Value::SingleQuotedString(s) => Some(s.to_string()),
        Value::DollarQuotedString(s) => Some(s.to_string()),
        // Numbers and booleans render through their Display impl
        Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
        Value::UnicodeStringLiteral(s) => Some(s.to_string()),
        Value::EscapedStringLiteral(s) => Some(s.to_string()),
        Value::QuoteDelimitedStringLiteral(s)
        | Value::NationalQuoteDelimitedStringLiteral(s) => Some(s.value.to_string()),
        Value::DoubleQuotedString(_)
        | Value::NationalStringLiteral(_)
        | Value::SingleQuotedByteStringLiteral(_)
        | Value::DoubleQuotedByteStringLiteral(_)
        | Value::TripleSingleQuotedString(_)
        | Value::TripleDoubleQuotedString(_)
        | Value::TripleSingleQuotedByteStringLiteral(_)
        | Value::TripleDoubleQuotedByteStringLiteral(_)
        | Value::SingleQuotedRawStringLiteral(_)
        | Value::DoubleQuotedRawStringLiteral(_)
        | Value::TripleSingleQuotedRawStringLiteral(_)
        | Value::TripleDoubleQuotedRawStringLiteral(_)
        | Value::HexStringLiteral(_)
        | Value::Null
        | Value::Placeholder(_) => None,
    }
}
353
354pub(crate) fn rewrite_recursive_unnests_bottom_up(
355    input: &LogicalPlan,
356    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
357    inner_projection_exprs: &mut Vec<Expr>,
358    original_exprs: &[Expr],
359) -> Result<Vec<Expr>> {
360    Ok(original_exprs
361        .iter()
362        .map(|expr| {
363            rewrite_recursive_unnest_bottom_up(
364                input,
365                unnest_placeholder_columns,
366                inner_projection_exprs,
367                expr,
368            )
369        })
370        .collect::<Result<Vec<_>>>()?
371        .into_iter()
372        .flatten()
373        .collect::<Vec<_>>())
374}
375
/// Prefix used for the synthetic columns produced by the inner projection
/// that feeds an `Unnest` node (see [`RecursiveUnnestRewriter`]).
pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";
377
/*
This rewriter is only useful when driven by a combined top-down/bottom-up
traversal (`TreeNode::rewrite`): `f_down` records the nesting of consecutive
unnest expressions and `f_up` performs the actual rewrite. A full example of
how the transformation works is given in the `f_up` documentation below.
 */
/// Tree-node rewriter that lowers `unnest(...)` expressions (including
/// recursive `unnest(unnest(col))`) into placeholder columns backed by an
/// inner projection; see [`rewrite_recursive_unnest_bottom_up`].
struct RecursiveUnnestRewriter<'a> {
    /// Schema the unnested expressions are resolved against.
    input_schema: &'a DFSchemaRef,
    /// The complete original expression this rewriter was started on.
    root_expr: &'a Expr,
    // Useful to detect which child expr is a part of/ not a part of unnest operation
    top_most_unnest: Option<Unnest>,
    /// Traversal history: `Some(unnest)` for each unnest node visited on the
    /// way down, `None` for every other node; used to detect recursive
    /// unnests (consecutive runs of `Some`).
    consecutive_unnest: Vec<Option<Unnest>>,
    /// Expressions accumulated for the inner projection (deduplicated by
    /// schema name via `push_projection_dedupl`).
    inner_projection_exprs: &'a mut Vec<Expr>,
    /// Placeholder column -> unnest instructions. `None` marks a struct
    /// unnest; `Some(lists)` records the depths of list unnests.
    columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    /// Set only for a root-level struct unnest, where one expression expands
    /// into multiple output expressions.
    transformed_root_exprs: Option<Vec<Expr>>,
}
impl RecursiveUnnestRewriter<'_> {
    /// This struct stores the history of expr
    /// during its tree-traversal with a notation of
    /// \[None,**Unnest(exprA)**,**Unnest(exprB)**,None,None\]
    /// then this function will returns \[**Unnest(exprA)**,**Unnest(exprB)**\]
    ///
    /// The first item will be the inner most expr
    fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
        self.consecutive_unnest
            .iter()
            // Walk the history backwards: most recently visited nodes first
            .rev()
            // Skip any trailing non-unnest entries ...
            .skip_while(|item| item.is_none())
            // ... then keep the contiguous run of unnest entries
            .take_while(|item| item.is_some())
            .to_owned()
            .cloned()
            // Safe to unwrap: take_while above only yields `Some` entries
            .map(|item| item.unwrap())
            .collect()
    }

    /// Check if the current expression is at the root level for struct unnest purposes.
    /// This is true if:
    /// 1. The expression IS the root expression, OR
    /// 2. The root expression is an Alias wrapping this expression
    ///
    /// This allows `unnest(struct_col) AS alias` to work, where the alias is simply
    /// ignored for struct unnest (matching DuckDB behavior).
    fn is_at_struct_allowed_root(&self, expr: &Expr) -> bool {
        if expr == self.root_expr {
            return true;
        }
        // Allow struct unnest when root is an alias wrapping the unnest
        if let Expr::Alias(Alias { expr: inner, .. }) = self.root_expr {
            return inner.as_ref() == expr;
        }
        false
    }

    /// Rewrite one (possibly recursive) unnest of `expr_in_unnest` at
    /// recursion depth `level` into placeholder column expression(s), and
    /// record the required inner-projection entry and unnest instructions.
    ///
    /// Returns the expression(s) to use in the outer projection: one per
    /// struct field for a struct unnest, exactly one for a list unnest.
    fn transform(
        &mut self,
        level: usize,
        alias_name: String,
        expr_in_unnest: &Expr,
        struct_allowed: bool,
    ) -> Result<Vec<Expr>> {
        let inner_expr_name = expr_in_unnest.schema_name().to_string();

        // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection
        // inside unnest execution, each column inside the inner projection
        // will be transformed into new columns. Thus we need to keep track of these placeholding column names
        let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})");
        let post_unnest_name =
            format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})");
        // This is due to the fact that unnest transformation should keep the original
        // column name as is, to comply with group by and order by
        let placeholder_column = Column::from_name(placeholder_name.clone());
        let field = expr_in_unnest.to_field(self.input_schema)?.1;
        let data_type = field.data_type();

        match data_type {
            DataType::Struct(inner_fields) => {
                assert_or_internal_err!(
                    struct_allowed,
                    "unnest on struct can only be applied at the root level of select expression"
                );
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );
                // `None` marks this placeholder as a struct unnest
                self.columns_unnestings
                    .insert(Column::from_name(placeholder_name.clone()), None);
                // Expand into one column expression per struct field
                Ok(get_struct_unnested_columns(&placeholder_name, inner_fields)
                    .into_iter()
                    .map(Expr::Column)
                    .collect())
            }
            DataType::List(_)
            | DataType::FixedSizeList(_, _)
            | DataType::LargeList(_) => {
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );

                // Alias back to the original name so GROUP BY / ORDER BY still match
                let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
                let list_unnesting = self
                    .columns_unnestings
                    .entry(placeholder_column)
                    .or_insert(Some(vec![]));
                let unnesting = ColumnUnnestList {
                    output_column: Column::from_name(post_unnest_name),
                    depth: level,
                };
                // NOTE(review): assumes any existing entry for this column was
                // created by a list unnest (`Some`); a prior struct entry
                // (`None`) would panic here — confirm this cannot occur
                let list_unnestings = list_unnesting.as_mut().unwrap();
                if !list_unnestings.contains(&unnesting) {
                    list_unnestings.push(unnesting);
                }
                Ok(vec![post_unnest_expr])
            }
            _ => {
                internal_err!("unnest on non-list or struct type is not supported")
            }
        }
    }
}
496
impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
    type Node = Expr;

    /// This downward traversal needs to keep track of:
    /// - Whether or not some unnest expr has been visited from the top until the current node
    /// - If some unnest expr has been visited, maintain a stack of such information, this
    ///   is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))**
    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref unnest_expr) = expr {
            let field = unnest_expr.expr.to_field(self.input_schema)?.1;
            let data_type = field.data_type();
            // Record this unnest in the traversal history
            self.consecutive_unnest.push(Some(unnest_expr.clone()));
            // if expr inside unnest is a struct, do not consider
            // the next unnest as consecutive unnest (if any)
            // meaning unnest(unnest(struct_arr_col)) can't
            // be interpreted as unnest(struct_arr_col, depth:=2)
            // but has to be split into multiple unnest logical plan instead
            // a.k.a:
            // - unnest(struct_col)
            //      unnest(struct_arr_col) as struct_col

            if let DataType::Struct(_) = data_type {
                self.consecutive_unnest.push(None);
            }
            // Remember the first (outermost) unnest seen on this path
            if self.top_most_unnest.is_none() {
                self.top_most_unnest = Some(unnest_expr.clone());
            }

            Ok(Transformed::no(expr))
        } else {
            // Non-unnest node: breaks any run of consecutive unnests
            self.consecutive_unnest.push(None);
            Ok(Transformed::no(expr))
        }
    }

    /// The rewriting only happens when the traversal has reached the top-most unnest expr
    /// within a sequence of consecutive unnest exprs node
    ///
    /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))**
    /// ```text
    ///                         ┌──────────────────┐
    ///                         │    binaryexpr    │
    ///                         │                  │
    ///                         └──────────────────┘
    ///                f_down  / /            │ │
    ///                       / / f_up        │ │
    ///                      / /        f_down│ │f_up
    ///                  unnest               │ │
    ///                                       │ │
    ///       f_down  / / f_up(rewriting)     │ │
    ///              / /
    ///             / /                      unnest
    ///         unnest
    ///                           f_down  / / f_up(rewriting)
    /// f_down / /f_up                   / /
    ///       / /                       / /
    ///      / /                    unnest
    ///   column1
    ///                     f_down / /f_up
    ///                           / /
    ///                          / /
    ///                       column2
    /// ```
    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref traversing_unnest) = expr {
            // f_down always sets top_most_unnest before any unnest is revisited
            // on the way up, so the unwrap cannot fail here
            if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
                self.top_most_unnest = None;
            }
            // Find inside consecutive_unnest, the sequence of continuous unnest exprs

            // Get the latest consecutive unnest exprs
            // and check if current upward traversal is the returning to the root expr
            // for example given a expr `unnest(unnest(col))` then the traversal happens like:
            // down(unnest) -> down(unnest) -> down(col) -> up(col) -> up(unnest) -> up(unnest)
            // the result of such traversal is unnest(col, depth:=2)
            let unnest_stack = self.get_latest_consecutive_unnest();

            // This traversal has reached the top most unnest again
            // e.g Unnest(top) -> Unnest(2nd) -> Column(bottom)
            // -> Unnest(2nd) -> Unnest(top) a.k.a here
            // Thus
            // Unnest(Unnest(some_col)) is rewritten into Unnest(some_col, depth:=2)
            if traversing_unnest == unnest_stack.last().unwrap() {
                let most_inner = unnest_stack.first().unwrap();
                let inner_expr = most_inner.expr.as_ref();
                // unnest(unnest(struct_arr_col)) is not allow to be done recursively
                // it needs to be split into multiple unnest logical plan
                // unnest(struct_arr)
                //  unnest(struct_arr_col) as struct_arr
                // instead of unnest(struct_arr_col, depth = 2)

                // The number of stacked unnests is the rewrite depth
                let unnest_recursion = unnest_stack.len();
                let struct_allowed =
                    self.is_at_struct_allowed_root(&expr) && unnest_recursion == 1;

                let mut transformed_exprs = self.transform(
                    unnest_recursion,
                    expr.schema_name().to_string(),
                    inner_expr,
                    struct_allowed,
                )?;
                // Only set transformed_root_exprs for struct unnest (which returns multiple expressions).
                // For list unnest (single expression), we let the normal rewrite handle the alias.
                if struct_allowed && transformed_exprs.len() > 1 {
                    self.transformed_root_exprs = Some(transformed_exprs.clone());
                }
                return Ok(Transformed::new(
                    transformed_exprs.swap_remove(0),
                    true,
                    TreeNodeRecursion::Continue,
                ));
            }
        } else {
            self.consecutive_unnest.push(None);
        }

        // For column exprs that are not descendants of any unnest node
        // retain their projection
        // e.g given expr tree unnest(col_a) + col_b, we have to retain projection of col_b
        // this condition can be checked by maintaining an Option<top most unnest>
        if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
            push_projection_dedupl(self.inner_projection_exprs, expr.clone());
        }

        Ok(Transformed::no(expr))
    }
}
624
625fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
626    let schema_name = expr.schema_name().to_string();
627    if !projection
628        .iter()
629        .any(|e| e.schema_name().to_string() == schema_name)
630    {
631        projection.push(expr);
632    }
633}
/// The context is we want to rewrite unnest() into InnerProjection->Unnest->OuterProjection
/// Given an expression which contains unnest expr as one of its children,
/// Try transform depends on unnest type
/// - For list column: unnest(col) with type list -> unnest(col) with type list::item
/// - For struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2
///
/// The transformed exprs will be used in the outer projection
/// If along the path from root to bottom, there are multiple unnest expressions, the transformation
/// is done only for the bottom expression
pub(crate) fn rewrite_recursive_unnest_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_expr: &Expr,
) -> Result<Vec<Expr>> {
    let mut rewriter = RecursiveUnnestRewriter {
        input_schema: input.schema(),
        root_expr: original_expr,
        top_most_unnest: None,
        consecutive_unnest: vec![],
        inner_projection_exprs,
        columns_unnestings: unnest_placeholder_columns,
        transformed_root_exprs: None,
    };

    // This transformation is only done for list unnest
    // struct unnest is done at the root level, and at the later stage
    // because the syntax of TreeNode only support transform into 1 Expr, while
    // Unnest struct will be transformed into multiple Exprs
    // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102
    //
    // The transformation looks like:
    // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)")
    // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1")
    let Transformed {
        data: transformed_expr,
        transformed,
        tnr: _,
    } = original_expr.clone().rewrite(&mut rewriter)?;

    if !transformed {
        // No unnest was found: the expression only needs to be routed through
        // the inner projection unchanged.
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        if matches!(&transformed_expr, Expr::Column(_))
            || matches!(&transformed_expr, Expr::Wildcard { .. })
        {
            push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
            Ok(vec![transformed_expr])
        } else {
            // We need to evaluate the expr in the inner projection,
            // outer projection just select its name
            let column_name = transformed_expr.schema_name().to_string();
            push_projection_dedupl(inner_projection_exprs, transformed_expr);
            Ok(vec![Expr::Column(Column::from_name(column_name))])
        }
    } else {
        // Root-level struct unnest expands to multiple expressions; prefer
        // those over the single rewritten expression when present.
        if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
            return Ok(transformed_root_exprs);
        }
        Ok(vec![transformed_expr])
    }
}
696
697#[cfg(test)]
698mod tests {
699    use std::{ops::Add, sync::Arc};
700
701    use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
702    use datafusion_common::{Column, DFSchema, Result};
703    use datafusion_expr::{
704        ColumnUnnestList, EmptyRelation, LogicalPlan, col, lit, unnest,
705    };
706    use datafusion_functions::core::expr_ext::FieldAccessor;
707    use datafusion_functions_aggregate::expr_fn::count;
708
709    use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
710    use indexmap::IndexMap;
711
712    fn column_unnests_eq(
713        l: Vec<&str>,
714        r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
715    ) {
716        let r_formatted: Vec<String> = r
717            .iter()
718            .map(|i| match i.1 {
719                None => format!("{}", i.0),
720                Some(vec) => format!(
721                    "{}=>[{}]",
722                    i.0,
723                    vec.iter()
724                        .map(|i| format!("{i}"))
725                        .collect::<Vec<String>>()
726                        .join(", ")
727                ),
728            })
729            .collect();
730        let l_formatted: Vec<String> = l.iter().map(|i| (*i).to_string()).collect();
731        assert_eq!(l_formatted, r_formatted);
732    }
733
    // Rewriting `unnest(unnest(3d_col))` must only lower the bottom-most
    // unnest into a depth-tagged placeholder column; the placeholder map and
    // inner projection accumulate across successive rewrite calls.
    #[test]
    fn test_transform_bottom_unnest_recursive() -> Result<()> {
        // 3d_col: List<List<Int64>> — requires two levels of unnesting to
        // reach the Int64 elements.
        let schema = Schema::new(vec![
            Field::new(
                "3d_col",
                ArrowDataType::List(Arc::new(Field::new(
                    "2d_col",
                    ArrowDataType::List(Arc::new(Field::new(
                        "elements",
                        ArrowDataType::Int64,
                        true,
                    ))),
                    true,
                ))),
                true,
            ),
            Field::new("i64_col", ArrowDataType::Int64, true),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        // Accumulators threaded (and mutated) through every rewrite call below.
        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(unnest(3d_col)) + unnest(unnest(3d_col))
        let original_expr = unnest(unnest(col("3d_col")))
            .add(unnest(unnest(col("3d_col"))))
            .add(col("i64_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // Only the bottom most unnest exprs are transformed
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(3d_col,depth=2)")
                    .alias("UNNEST(UNNEST(3d_col))")
                    .add(
                        col("__unnest_placeholder(3d_col,depth=2)")
                            .alias("UNNEST(UNNEST(3d_col))")
                    )
                    .add(col("i64_col"))
            ]
        );
        // The two identical unnest chains share one placeholder entry.
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
            ],
            &unnest_placeholder_columns,
        );

        // Still reference 3d_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        // unnest(3d_col) as 2d_col
        let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr_2,
        )?;

        // The user-supplied alias ("2d_col") stays on top of the rewrite.
        assert_eq!(
            transformed_exprs,
            vec![
                (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
                    .alias("2d_col")
            ]
        );
        // The new depth=1 recursion is appended to the existing 3d_col entry.
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Still reference 3d_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any;
        // no duplicate projection is added for the already-seen column.
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        Ok(())
    }
837
    // Non-recursive unnests: a struct column expands into one output column
    // per struct field, while a list column becomes a single depth-tagged
    // placeholder; the accumulators keep appending across calls.
    #[test]
    fn test_transform_bottom_unnest() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "struct_col",
                ArrowDataType::Struct(Fields::from(vec![
                    Field::new("field1", ArrowDataType::Int32, false),
                    Field::new("field2", ArrowDataType::Int32, false),
                ])),
                false,
            ),
            Field::new(
                "array_col",
                ArrowDataType::List(Arc::new(Field::new_list_field(
                    ArrowDataType::Int64,
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        // Accumulators threaded (and mutated) through every rewrite call below.
        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(struct_col)
        let original_expr = unnest(col("struct_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // A struct unnest fans out into one column per field.
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(struct_col).field1"),
                col("__unnest_placeholder(struct_col).field2"),
            ]
        );
        // Struct unnests carry no recursion list (no `=>[...]` suffix).
        column_unnests_eq(
            vec!["__unnest_placeholder(struct_col)"],
            &unnest_placeholder_columns,
        );
        // Still reference struct_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
        );

        // unnest(array_col) + 1
        let original_expr = unnest(col("array_col")).add(lit(1i64));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // The earlier struct_col entry is retained; array_col is appended.
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_col)",
                "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Only transform the unnest children
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(array_col,depth=1)")
                    .alias("UNNEST(array_col)")
                    .add(lit(1i64))
            ]
        );

        // Keep appending to the current vector
        // Still reference array_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("struct_col").alias("__unnest_placeholder(struct_col)"),
                col("array_col").alias("__unnest_placeholder(array_col)")
            ]
        );

        Ok(())
    }
934
    // Unnest -> field access -> unnest
    // When unnests are separated by a field access, only the bottom-most
    // unnest is rewritten per pass, and two expressions derived from the
    // same source unnest share one placeholder entry and one projection.
    #[test]
    fn test_transform_non_consecutive_unnests() -> Result<()> {
        // List of struct
        // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}]
        let schema = Schema::new(vec![
            Field::new(
                "struct_list",
                ArrowDataType::List(Arc::new(Field::new(
                    "element",
                    ArrowDataType::Struct(Fields::from(vec![
                        Field::new(
                            // list of i64
                            "subfield1",
                            ArrowDataType::List(Arc::new(Field::new(
                                "i64_element",
                                ArrowDataType::Int64,
                                true,
                            ))),
                            true,
                        ),
                        Field::new(
                            // list of utf8
                            "subfield2",
                            ArrowDataType::List(Arc::new(Field::new(
                                "utf8_element",
                                ArrowDataType::Utf8,
                                true,
                            ))),
                            true,
                        ),
                    ])),
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        // Accumulators threaded (and mutated) through both rewrite calls below.
        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // An expr with multiple unnest
        let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr1,
        )?;
        // Only the inner most/ bottom most unnest is transformed;
        // the outer unnest remains wrapped around the field access.
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield1")
            )]
        );

        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        // continue rewrite another expr in select
        let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr2,
        )?;
        // Only the inner most/ bottom most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield2")
            )]
        );

        // unnest place holder columns remain the same
        // because expr1 and expr2 derive from the same unnest result
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        // No duplicate projection is added for the shared struct_list column.
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        Ok(())
    }
1048
1049    #[test]
1050    fn test_resolve_positions_to_exprs() -> Result<()> {
1051        let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];
1052
1053        // Assert 1 resolved as first column in select list
1054        let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
1055        assert_eq!(resolved, col("c1"));
1056
1057        // Assert error if index out of select clause bounds
1058        let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
1059        assert!(resolved.is_err_and(|e| e.message().contains(
1060            "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
1061        )));
1062
1063        let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
1064        assert!(resolved.is_err_and(|e| e.message().contains(
1065            "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
1066        )));
1067
1068        // Assert expression returned as-is
1069        let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
1070        assert_eq!(resolved, lit("text"));
1071
1072        let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
1073        assert_eq!(resolved, col("fake"));
1074
1075        Ok(())
1076    }
1077}