// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! SQL Utility Functions

use std::vec;

use arrow::datatypes::{
    DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
};
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
};
use datafusion_common::{
    exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchemaRef,
    Diagnostic, HashMap, Result, ScalarValue,
};
use datafusion_expr::builder::get_struct_unnested_columns;
use datafusion_expr::expr::{
    Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
use datafusion_expr::{
    col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan,
};

use indexmap::IndexMap;
use sqlparser::ast::{Ident, Value};

/// Make a best-effort attempt at resolving all columns in the expression tree
pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    expr.clone()
        .transform_up(|nested_expr| {
            match nested_expr {
                Expr::Column(col) => {
                    let (qualifier, field) =
                        plan.schema().qualified_field_from_column(&col)?;
                    Ok(Transformed::yes(Expr::Column(Column::from((
                        qualifier, field,
                    )))))
                }
                _ => {
                    // keep recursing
                    Ok(Transformed::no(nested_expr))
                }
            }
        })
        .data()
}

/// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
///
/// For example, the expression `a + b < 1` would require, as input, the 2
/// individual columns, `a` and `b`. But, if the base expressions already
/// contain the `a + b` result, then that may be used in lieu of the `a` and
/// `b` columns.
///
/// This is useful in the context of a query like:
///
/// SELECT a + b < 1 ... GROUP BY a + b
///
/// where post-aggregation, `a + b` need not be a projection against the
/// individual columns `a` and `b`, but rather it is a projection against the
/// `a + b` found in the GROUP BY.
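///
/// A minimal sketch of the intended rewrite (hypothetical names, marked
/// `ignore` since this is a `pub(crate)` helper, not a doctest):
///
/// ```ignore
/// // the GROUP BY already produces `a + b`
/// let base_exprs = vec![col("a") + col("b")];
/// // `(a + b) < 1` is rebased to reference the grouped expression as a
/// // column instead of the individual columns `a` and `b`
/// let rebased = rebase_expr(&(col("a") + col("b")).lt(lit(1)), &base_exprs, &plan)?;
/// ```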
pub(crate) fn rebase_expr(
    expr: &Expr,
    base_exprs: &[Expr],
    plan: &LogicalPlan,
) -> Result<Expr> {
    expr.clone()
        .transform_down(|nested_expr| {
            if base_exprs.contains(&nested_expr) {
                Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
            } else {
                Ok(Transformed::no(nested_expr))
            }
        })
        .data()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
    Projection,
    Having,
    Qualify,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
    Aggregate(CheckColumnsMustReferenceAggregatePurpose),
}

impl CheckColumnsSatisfyExprsPurpose {
    fn message_prefix(&self) -> &'static str {
        match self {
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
                "Column in SELECT must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
                "Column in HAVING must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
                "Column in QUALIFY must be in GROUP BY or an aggregate function"
            }
        }
    }

    fn diagnostic_message(&self, expr: &Expr) -> String {
        format!("'{expr}' must appear in GROUP BY clause because it's not an aggregate expression")
    }
}
/// Determines if the set of `Expr`'s is a valid projection on the input
/// `Expr::Column`'s.
pub(crate) fn check_columns_satisfy_exprs(
    columns: &[Expr],
    exprs: &[Expr],
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    columns.iter().try_for_each(|c| match c {
        Expr::Column(_) => Ok(()),
        _ => internal_err!("Expr::Column are required"),
    })?;
    let column_exprs = find_column_exprs(exprs);
    for e in &column_exprs {
        match e {
            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
                for exprs in lists_of_exprs {
                    for e in exprs {
                        check_column_satisfies_expr(columns, e, purpose)?;
                    }
                }
            }
            _ => check_column_satisfies_expr(columns, e, purpose)?,
        }
    }
    Ok(())
}

fn check_column_satisfies_expr(
    columns: &[Expr],
    expr: &Expr,
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    if !columns.contains(expr) {
        let diagnostic = Diagnostic::new_error(
            purpose.diagnostic_message(expr),
            expr.spans().and_then(|spans| spans.first()),
        )
        .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);

        return plan_err!(
            "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
            purpose.message_prefix(),
            expr,
            expr_vec_fmt!(columns);
            diagnostic=diagnostic
        );
    }
    Ok(())
}

/// Returns mapping of each alias (`String`) to the expression (`Expr`) it is
/// aliasing.
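///
/// A small sketch of the expected shape (hypothetical exprs, marked `ignore`
/// since this is a `pub(crate)` helper):
///
/// ```ignore
/// // SELECT c1 + 1 AS total ...
/// let exprs = vec![(col("c1") + lit(1)).alias("total")];
/// let aliases = extract_aliases(&exprs);
/// // aliases["total"] == col("c1") + lit(1)
/// ```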
pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
    exprs
        .iter()
        .filter_map(|expr| match expr {
            Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
            _ => None,
        })
        .collect::<HashMap<String, Expr>>()
}

/// Given an expression that is a literal integer encoding a position (as in `ORDER BY 1`),
/// look up the corresponding expression in the `select_exprs` list; returns a planning
/// error if the position is out of bounds.
/// If the input expression is not an integer literal, it is returned as-is.
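///
/// A minimal sketch mirroring the unit tests below (marked `ignore` since this
/// is a `pub(crate)` helper):
///
/// ```ignore
/// let select_exprs = vec![col("c1"), col("c2")];
/// // `1` resolves to the first item of the SELECT list
/// assert_eq!(resolve_positions_to_exprs(lit(1i64), &select_exprs)?, col("c1"));
/// // non-integer expressions pass through unchanged
/// assert_eq!(resolve_positions_to_exprs(col("c2"), &select_exprs)?, col("c2"));
/// ```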
pub(crate) fn resolve_positions_to_exprs(
    expr: Expr,
    select_exprs: &[Expr],
) -> Result<Expr> {
    match expr {
        // sql_expr_to_logical_expr maps number to i64
        // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887
        Expr::Literal(ScalarValue::Int64(Some(position)), _)
            if position > 0_i64 && position <= select_exprs.len() as i64 =>
        {
            let index = (position - 1) as usize;
            let select_expr = &select_exprs[index];
            Ok(match select_expr {
                Expr::Alias(Alias { expr, .. }) => *expr.clone(),
                _ => select_expr.clone(),
            })
        }
        Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!(
            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
            position, select_exprs.len()
        ),
        _ => Ok(expr),
    }
}

/// Rebuilds an `Expr` with columns that refer to aliases replaced by the
/// alias' underlying `Expr`.
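///
/// A short sketch (hypothetical input, marked `ignore`):
///
/// ```ignore
/// // HAVING total > 10, where `total` aliases `SUM(c1)` in the SELECT list
/// let mut aliases = HashMap::new();
/// aliases.insert("total".to_string(), sum(col("c1")));
/// let resolved = resolve_aliases_to_exprs(col("total").gt(lit(10)), &aliases)?;
/// // resolved == sum(col("c1")).gt(lit(10))
/// ```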
pub(crate) fn resolve_aliases_to_exprs(
    expr: Expr,
    aliases: &HashMap<String, Expr>,
) -> Result<Expr> {
    expr.transform_up(|nested_expr| match nested_expr {
        Expr::Column(c) if c.relation.is_none() => {
            if let Some(aliased_expr) = aliases.get(&c.name) {
                Ok(Transformed::yes(aliased_expr.clone()))
            } else {
                Ok(Transformed::no(Expr::Column(c)))
            }
        }
        _ => Ok(Transformed::no(nested_expr)),
    })
    .data()
}

/// Given a slice of window expressions sharing the same sort key, find their common partition
/// keys.
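///
/// In practice this returns the shortest `PARTITION BY` list among the given
/// window expressions. A sketch (hypothetical SQL, marked `ignore`):
///
/// ```ignore
/// // OVER (PARTITION BY a, b ...) and OVER (PARTITION BY a ...)
/// // yield the common partition key [a]
/// let common = window_expr_common_partition_keys(&window_exprs)?;
/// ```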
pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
    let all_partition_keys = window_exprs
        .iter()
        .map(|expr| match expr {
            Expr::WindowFunction(window_fun) => {
                let WindowFunction {
                    params: WindowFunctionParams { partition_by, .. },
                    ..
                } = window_fun.as_ref();
                Ok(partition_by)
            }
            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
                Expr::WindowFunction(window_fun) => {
                    let WindowFunction {
                        params: WindowFunctionParams { partition_by, .. },
                        ..
                    } = window_fun.as_ref();
                    Ok(partition_by)
                }
                expr => exec_err!("Impossibly got non-window expr {expr:?}"),
            },
            expr => exec_err!("Impossibly got non-window expr {expr:?}"),
        })
        .collect::<Result<Vec<_>>>()?;
    let result = all_partition_keys
        .iter()
        .min_by_key(|s| s.len())
        .ok_or_else(|| exec_datafusion_err!("No window expressions found"))?;
    Ok(result)
}

/// Returns a validated `DataType` for the specified precision and
/// scale
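///
/// A minimal sketch of the expected mapping (marked `ignore` since this is a
/// `pub(crate)` helper):
///
/// ```ignore
/// // precision <= 38 stays in Decimal128, larger precisions use Decimal256
/// assert_eq!(make_decimal_type(Some(10), Some(2))?, DataType::Decimal128(10, 2));
/// assert_eq!(make_decimal_type(Some(40), Some(2))?, DataType::Decimal256(40, 2));
/// // specifying only a scale is rejected
/// assert!(make_decimal_type(None, Some(2)).is_err());
/// ```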
pub(crate) fn make_decimal_type(
    precision: Option<u64>,
    scale: Option<u64>,
) -> Result<DataType> {
    // postgres like behavior
    let (precision, scale) = match (precision, scale) {
        (Some(p), Some(s)) => (p as u8, s as i8),
        (Some(p), None) => (p as u8, 0),
        (None, Some(_)) => {
            return plan_err!("Cannot specify only scale for decimal data type")
        }
        (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
    };

    if precision == 0
        || precision > DECIMAL256_MAX_PRECISION
        || scale.unsigned_abs() > precision
    {
        plan_err!(
            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
        )
    } else if precision > DECIMAL128_MAX_PRECISION
        && precision <= DECIMAL256_MAX_PRECISION
    {
        Ok(DataType::Decimal256(precision, scale))
    } else {
        Ok(DataType::Decimal128(precision, scale))
    }
}

/// Normalize an owned identifier to a lowercase string, unless the identifier is quoted.
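///
/// A short sketch (marked `ignore` since this is a `pub(crate)` helper):
///
/// ```ignore
/// assert_eq!(normalize_ident(Ident::new("MyTable")), "mytable");
/// assert_eq!(normalize_ident(Ident::with_quote('"', "MyTable")), "MyTable");
/// ```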
pub(crate) fn normalize_ident(id: Ident) -> String {
    match id.quote_style {
        Some(_) => id.value,
        None => id.value.to_ascii_lowercase(),
    }
}

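/// Convert a SQL `Value` literal to its string contents, returning `None` for
/// literal kinds that are not handled here.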
pub(crate) fn value_to_string(value: &Value) -> Option<String> {
    match value {
        Value::SingleQuotedString(s) => Some(s.to_string()),
        Value::DollarQuotedString(s) => Some(s.to_string()),
        Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
        Value::UnicodeStringLiteral(s) => Some(s.to_string()),
        Value::EscapedStringLiteral(s) => Some(s.to_string()),
        Value::DoubleQuotedString(_)
        | Value::NationalStringLiteral(_)
        | Value::SingleQuotedByteStringLiteral(_)
        | Value::DoubleQuotedByteStringLiteral(_)
        | Value::TripleSingleQuotedString(_)
        | Value::TripleDoubleQuotedString(_)
        | Value::TripleSingleQuotedByteStringLiteral(_)
        | Value::TripleDoubleQuotedByteStringLiteral(_)
        | Value::SingleQuotedRawStringLiteral(_)
        | Value::DoubleQuotedRawStringLiteral(_)
        | Value::TripleSingleQuotedRawStringLiteral(_)
        | Value::TripleDoubleQuotedRawStringLiteral(_)
        | Value::HexStringLiteral(_)
        | Value::Null
        | Value::Placeholder(_) => None,
    }
}

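/// Apply [`rewrite_recursive_unnest_bottom_up`] to each expression in
/// `original_exprs`, flattening the per-expression results into one list.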
pub(crate) fn rewrite_recursive_unnests_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_exprs: &[Expr],
) -> Result<Vec<Expr>> {
    Ok(original_exprs
        .iter()
        .map(|expr| {
            rewrite_recursive_unnest_bottom_up(
                input,
                unnest_placeholder_columns,
                inner_projection_exprs,
                expr,
            )
        })
        .collect::<Result<Vec<_>>>()?
        .into_iter()
        .flatten()
        .collect::<Vec<_>>())
}

pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";

/*
This rewriter is only useful when driven by a combined down/up traversal
(i.e. `TreeNode::rewrite`); see the `f_up` documentation below for a full
example of how the transformation works.
 */
struct RecursiveUnnestRewriter<'a> {
    input_schema: &'a DFSchemaRef,
    root_expr: &'a Expr,
    // Used to detect whether a child expr is part of an unnest operation or not
    top_most_unnest: Option<Unnest>,
    consecutive_unnest: Vec<Option<Unnest>>,
    inner_projection_exprs: &'a mut Vec<Expr>,
    columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    transformed_root_exprs: Option<Vec<Expr>>,
}
impl RecursiveUnnestRewriter<'_> {
    /// The `consecutive_unnest` field stores the history of exprs seen during the
    /// tree traversal, e.g.
    /// \[None,**Unnest(exprA)**,**Unnest(exprB)**,None,None\];
    /// given that history, this function returns \[**Unnest(exprA)**,**Unnest(exprB)**\].
    ///
    /// The first item is the innermost expr
    fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
        self.consecutive_unnest
            .iter()
            .rev()
            .skip_while(|item| item.is_none())
            .take_while(|item| item.is_some())
            .to_owned()
            .cloned()
            .map(|item| item.unwrap())
            .collect()
    }

    fn transform(
        &mut self,
        level: usize,
        alias_name: String,
        expr_in_unnest: &Expr,
        struct_allowed: bool,
    ) -> Result<Vec<Expr>> {
        let inner_expr_name = expr_in_unnest.schema_name().to_string();

        // Full context: we are trying to plan the execution as InnerProjection->Unnest->OuterProjection.
        // During unnest execution, each column inside the inner projection
        // will be transformed into new columns, so we need to keep track of these placeholder column names
        let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})");
        let post_unnest_name =
            format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})");
        // This is due to the fact that the unnest transformation should keep the original
        // column name as is, to comply with GROUP BY and ORDER BY
        let placeholder_column = Column::from_name(placeholder_name.clone());

        let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?;

        match data_type {
            DataType::Struct(inner_fields) => {
                if !struct_allowed {
                    return internal_err!("unnest on struct can only be applied at the root level of select expression");
                }
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );
                self.columns_unnestings
                    .insert(Column::from_name(placeholder_name.clone()), None);
                Ok(
                    get_struct_unnested_columns(&placeholder_name, &inner_fields)
                        .into_iter()
                        .map(Expr::Column)
                        .collect(),
                )
            }
            DataType::List(_)
            | DataType::FixedSizeList(_, _)
            | DataType::LargeList(_) => {
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );

                let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
                let list_unnesting = self
                    .columns_unnestings
                    .entry(placeholder_column)
                    .or_insert(Some(vec![]));
                let unnesting = ColumnUnnestList {
                    output_column: Column::from_name(post_unnest_name),
                    depth: level,
                };
                let list_unnestings = list_unnesting.as_mut().unwrap();
                if !list_unnestings.contains(&unnesting) {
                    list_unnestings.push(unnesting);
                }
                Ok(vec![post_unnest_expr])
            }
            _ => {
                internal_err!("unnest on non-list or struct type is not supported")
            }
        }
    }
}

impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
    type Node = Expr;

    /// This downward traversal needs to keep track of:
    /// - Whether or not some unnest expr has been visited from the top until the current node
    /// - If some unnest expr has been visited, maintain a stack of such information, this
    ///   is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))**
    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref unnest_expr) = expr {
            let (data_type, _) =
                unnest_expr.expr.data_type_and_nullable(self.input_schema)?;
            self.consecutive_unnest.push(Some(unnest_expr.clone()));
            // If the expr inside the unnest is a struct, do not consider
            // the next unnest (if any) as consecutive,
            // meaning unnest(unnest(struct_arr_col)) can't
            // be interpreted as unnest(struct_arr_col, depth:=2)
            // but has to be split into multiple unnest logical plans instead,
            // a.k.a:
            // - unnest(struct_col)
            //      unnest(struct_arr_col) as struct_col

            if let DataType::Struct(_) = data_type {
                self.consecutive_unnest.push(None);
            }
            if self.top_most_unnest.is_none() {
                self.top_most_unnest = Some(unnest_expr.clone());
            }

            Ok(Transformed::no(expr))
        } else {
            self.consecutive_unnest.push(None);
            Ok(Transformed::no(expr))
        }
    }

    /// The rewriting only happens when the traversal has reached the top-most unnest expr
    /// within a sequence of consecutive unnest exprs
    ///
    /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))**
    /// ```text
    ///                         ┌──────────────────┐
    ///                         │    binaryexpr    │
    ///                         │                  │
    ///                         └──────────────────┘
    ///                f_down  / /            │ │
    ///                       / / f_up        │ │
    ///                      / /        f_down│ │f_up
    ///                  unnest               │ │
    ///                                       │ │
    ///       f_down  / / f_up(rewriting)     │ │
    ///              / /
    ///             / /                      unnest
    ///         unnest
    ///                           f_down  / / f_up(rewriting)
    /// f_down / /f_up                   / /
    ///       / /                       / /
    ///      / /                    unnest
    ///   column1
    ///                     f_down / /f_up
    ///                           / /
    ///                          / /
    ///                       column2
    /// ```
    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref traversing_unnest) = expr {
            if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
                self.top_most_unnest = None;
            }
            // Find inside consecutive_unnest, the sequence of continuous unnest exprs

            // Get the latest consecutive unnest exprs
            // and check if the current upward traversal is returning to the root expr.
            // For example, given an expr `unnest(unnest(col))` the traversal happens like:
            // down(unnest) -> down(unnest) -> down(col) -> up(col) -> up(unnest) -> up(unnest)
            // the result of such traversal is unnest(col, depth:=2)
            let unnest_stack = self.get_latest_consecutive_unnest();

            // This traversal has reached the topmost unnest again
            // e.g Unnest(top) -> Unnest(2nd) -> Column(bottom)
            // -> Unnest(2nd) -> Unnest(top) a.k.a here
            // Thus
            // Unnest(Unnest(some_col)) is rewritten into Unnest(some_col, depth:=2)
            if traversing_unnest == unnest_stack.last().unwrap() {
                let most_inner = unnest_stack.first().unwrap();
                let inner_expr = most_inner.expr.as_ref();
                // unnest(unnest(struct_arr_col)) is not allowed to be done recursively;
                // it needs to be split into multiple unnest logical plans:
                // unnest(struct_arr)
                //  unnest(struct_arr_col) as struct_arr
                // instead of unnest(struct_arr_col, depth = 2)

                let unnest_recursion = unnest_stack.len();
                let struct_allowed = (&expr == self.root_expr) && unnest_recursion == 1;

                let mut transformed_exprs = self.transform(
                    unnest_recursion,
                    expr.schema_name().to_string(),
                    inner_expr,
                    struct_allowed,
                )?;
                if struct_allowed {
                    self.transformed_root_exprs = Some(transformed_exprs.clone());
                }
                return Ok(Transformed::new(
                    transformed_exprs.swap_remove(0),
                    true,
                    TreeNodeRecursion::Continue,
                ));
            }
        } else {
            self.consecutive_unnest.push(None);
        }

        // For column exprs that are not descendants of any unnest node,
        // retain their projection
        // e.g given the expr tree unnest(col_a) + col_b, we have to retain the projection of col_b
        // this condition can be checked by maintaining an Option<top most unnest>
        if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
            push_projection_dedupl(self.inner_projection_exprs, expr.clone());
        }

        Ok(Transformed::no(expr))
    }
}

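/// Push `expr` into `projection` unless an expression with the same schema name
/// is already present (deduplicate by schema name).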
fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
    let schema_name = expr.schema_name().to_string();
    if !projection
        .iter()
        .any(|e| e.schema_name().to_string() == schema_name)
    {
        projection.push(expr);
    }
}
/// The context is that we want to rewrite unnest() into InnerProjection->Unnest->OuterProjection.
/// Given an expression which contains an unnest expr as one of its children,
/// the transformation depends on the unnest type:
/// - For a list column: unnest(col) with type list -> unnest(col) with type list::item
/// - For a struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2
///
/// The transformed exprs will be used in the outer projection.
/// If there are multiple unnest expressions along the path from root to bottom, the transformation
/// is done only for the bottom expression.
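///
/// A sketch of the rewrite, mirroring the unit tests below (placeholder names
/// follow this module's `__unnest_placeholder(...)` convention):
///
/// ```text
/// original (outer) expr:   unnest(array_col) + 1
/// inner projection gains:  array_col AS __unnest_placeholder(array_col)
/// rewritten (outer) expr:  __unnest_placeholder(array_col,depth=1) AS UNNEST(array_col) + 1
/// ```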
pub(crate) fn rewrite_recursive_unnest_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_expr: &Expr,
) -> Result<Vec<Expr>> {
    let mut rewriter = RecursiveUnnestRewriter {
        input_schema: input.schema(),
        root_expr: original_expr,
        top_most_unnest: None,
        consecutive_unnest: vec![],
        inner_projection_exprs,
        columns_unnestings: unnest_placeholder_columns,
        transformed_root_exprs: None,
    };

    // This transformation is only done for list unnest;
    // struct unnest is done at the root level, and at a later stage,
    // because the TreeNode API only supports transforming into 1 Expr, while
    // unnest on a struct will be transformed into multiple Exprs
    // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102
    //
    // The transformation looks like:
    // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)")
    // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1")
    let Transformed {
        data: transformed_expr,
        transformed,
        tnr: _,
    } = original_expr.clone().rewrite(&mut rewriter)?;

    if !transformed {
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        if matches!(&transformed_expr, Expr::Column(_))
            || matches!(&transformed_expr, Expr::Wildcard { .. })
        {
            push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
            Ok(vec![transformed_expr])
        } else {
            // We need to evaluate the expr in the inner projection;
            // the outer projection just selects it by name
            let column_name = transformed_expr.schema_name().to_string();
            push_projection_dedupl(inner_projection_exprs, transformed_expr);
            Ok(vec![Expr::Column(Column::from_name(column_name))])
        }
    } else {
        if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
            return Ok(transformed_root_exprs);
        }
        Ok(vec![transformed_expr])
    }
}

#[cfg(test)]
mod tests {
    use std::{ops::Add, sync::Arc};

    use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
    use datafusion_common::{Column, DFSchema, Result};
    use datafusion_expr::{
        col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan,
    };
    use datafusion_functions::core::expr_ext::FieldAccessor;
    use datafusion_functions_aggregate::expr_fn::count;

    use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
    use indexmap::IndexMap;

    fn column_unnests_eq(
        l: Vec<&str>,
        r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    ) {
        let r_formatted: Vec<String> = r
            .iter()
            .map(|i| match i.1 {
                None => format!("{}", i.0),
                Some(vec) => format!(
                    "{}=>[{}]",
                    i.0,
                    vec.iter()
                        .map(|i| format!("{i}"))
                        .collect::<Vec<String>>()
                        .join(", ")
                ),
            })
            .collect();
        let l_formatted: Vec<String> = l.iter().map(|i| (*i).to_string()).collect();
        assert_eq!(l_formatted, r_formatted);
    }

    #[test]
    fn test_transform_bottom_unnest_recursive() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "3d_col",
                ArrowDataType::List(Arc::new(Field::new(
                    "2d_col",
                    ArrowDataType::List(Arc::new(Field::new(
                        "elements",
                        ArrowDataType::Int64,
                        true,
                    ))),
                    true,
                ))),
                true,
            ),
            Field::new("i64_col", ArrowDataType::Int64, true),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(unnest(3d_col)) + unnest(unnest(3d_col))
        let original_expr = unnest(unnest(col("3d_col")))
            .add(unnest(unnest(col("3d_col"))))
            .add(col("i64_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // Only the bottom most unnest exprs are transformed
        assert_eq!(
            transformed_exprs,
            vec![col("__unnest_placeholder(3d_col,depth=2)")
                .alias("UNNEST(UNNEST(3d_col))")
                .add(
                    col("__unnest_placeholder(3d_col,depth=2)")
                        .alias("UNNEST(UNNEST(3d_col))")
                )
                .add(col("i64_col"))]
        );
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
            ],
            &unnest_placeholder_columns,
        );

        // Still reference 3d_col in the original schema but with an alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        // unnest(3d_col) as 2d_col
        let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr_2,
        )?;

        assert_eq!(
            transformed_exprs,
            vec![
                (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
                    .alias("2d_col")
            ]
        );
        column_unnests_eq(
            vec!["__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]"],
            &unnest_placeholder_columns,
        );
        // Still reference 3d_col in the original schema but with an alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        Ok(())
    }

    #[test]
    fn test_transform_bottom_unnest() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "struct_col",
                ArrowDataType::Struct(Fields::from(vec![
                    Field::new("field1", ArrowDataType::Int32, false),
                    Field::new("field2", ArrowDataType::Int32, false),
                ])),
                false,
            ),
            Field::new(
                "array_col",
                ArrowDataType::List(Arc::new(Field::new_list_field(
                    ArrowDataType::Int64,
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(struct_col)
        let original_expr = unnest(col("struct_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(struct_col).field1"),
                col("__unnest_placeholder(struct_col).field2"),
            ]
        );
        column_unnests_eq(
            vec!["__unnest_placeholder(struct_col)"],
            &unnest_placeholder_columns,
        );
        // Still reference struct_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
        );

        // unnest(array_col) + 1
        let original_expr = unnest(col("array_col")).add(lit(1i64));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_col)",
                "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Only transform the unnest children
        assert_eq!(
            transformed_exprs,
            vec![col("__unnest_placeholder(array_col,depth=1)")
                .alias("UNNEST(array_col)")
                .add(lit(1i64))]
        );

        // Keep appending to the current vector
        // Still reference array_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("struct_col").alias("__unnest_placeholder(struct_col)"),
                col("array_col").alias("__unnest_placeholder(array_col)")
            ]
        );

        Ok(())
    }

    // Unnest -> field access -> unnest
    #[test]
    fn test_transform_non_consecutive_unnests() -> Result<()> {
        // List of struct
        // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}]
        let schema = Schema::new(vec![
            Field::new(
                "struct_list",
                ArrowDataType::List(Arc::new(Field::new(
                    "element",
                    ArrowDataType::Struct(Fields::from(vec![
                        Field::new(
                            // list of i64
                            "subfield1",
                            ArrowDataType::List(Arc::new(Field::new(
                                "i64_element",
                                ArrowDataType::Int64,
                                true,
                            ))),
                            true,
                        ),
                        Field::new(
                            // list of utf8
                            "subfield2",
                            ArrowDataType::List(Arc::new(Field::new(
                                "utf8_element",
                                ArrowDataType::Utf8,
                                true,
                            ))),
                            true,
                        ),
                    ])),
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // An expr with multiple unnest
        let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr1,
        )?;
        // Only the inner most/ bottom most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield1")
            )]
        );

        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        // continue rewrite another expr in select
        let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr2,
        )?;
        // Only the inner most/ bottom most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield2")
            )]
        );

        // unnest placeholder columns remain the same
        // because select_expr1 and select_expr2 derive from the same unnest result
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        Ok(())
    }

    #[test]
    fn test_resolve_positions_to_exprs() -> Result<()> {
        let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];

        // Assert 1 resolved as first column in select list
        let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
        assert_eq!(resolved, col("c1"));

        // Assert error if index out of select clause bounds
        let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
        )));

        let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
        )));

        // Assert expression returned as-is
        let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
        assert_eq!(resolved, lit("text"));

        let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
        assert_eq!(resolved, col("fake"));

        Ok(())
    }
}