datafusion_sql/utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! SQL Utility Functions

use std::vec;

use arrow::datatypes::{
    DECIMAL_DEFAULT_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType,
};
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
};
use datafusion_common::{
    Column, DFSchemaRef, Diagnostic, HashMap, Result, ScalarValue,
    assert_or_internal_err, exec_datafusion_err, exec_err, internal_err, plan_err,
};
use datafusion_expr::builder::get_struct_unnested_columns;
use datafusion_expr::expr::{
    Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
use datafusion_expr::{
    ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, col, expr_vec_fmt,
};

use indexmap::IndexMap;
use sqlparser::ast::{Ident, Value};

/// Make a best-effort attempt at resolving all columns in the expression tree
pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    expr.clone()
        .transform_up(|nested_expr| {
            match nested_expr {
                Expr::Column(col) => {
                    let (qualifier, field) =
                        plan.schema().qualified_field_from_column(&col)?;
                    Ok(Transformed::yes(Expr::Column(Column::from((
                        qualifier, field,
                    )))))
                }
                _ => {
                    // keep recursing
                    Ok(Transformed::no(nested_expr))
                }
            }
        })
        .data()
}

/// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
///
/// For example, the expression `a + b < 1` would require, as input, the 2
/// individual columns, `a` and `b`. But, if the base expressions already
/// contain the `a + b` result, then that may be used in lieu of the `a` and
/// `b` columns.
///
/// This is useful in the context of a query like:
///
/// SELECT a + b < 1 ... GROUP BY a + b
///
/// where post-aggregation, `a + b` need not be a projection against the
/// individual columns `a` and `b`, but rather it is a projection against the
/// `a + b` found in the GROUP BY.
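///
/// A rough sketch of the effect, assuming `base_exprs` already contains `a + b`:
///
/// ```text
/// rebase_expr(a + b < 1, [a + b], plan)
///   => col("a + b") < 1    // `a + b` is referenced by name instead of being recomputed
/// ```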
pub(crate) fn rebase_expr(
    expr: &Expr,
    base_exprs: &[Expr],
    plan: &LogicalPlan,
) -> Result<Expr> {
    expr.clone()
        .transform_down(|nested_expr| {
            if base_exprs.contains(&nested_expr) {
                Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
            } else {
                Ok(Transformed::no(nested_expr))
            }
        })
        .data()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
    Projection,
    Having,
    Qualify,
    OrderBy,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
    Aggregate(CheckColumnsMustReferenceAggregatePurpose),
}

impl CheckColumnsSatisfyExprsPurpose {
    fn message_prefix(&self) -> &'static str {
        match self {
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
                "Column in SELECT must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
                "Column in HAVING must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
                "Column in QUALIFY must be in GROUP BY or an aggregate function"
            }
            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::OrderBy) => {
                "Column in ORDER BY must be in GROUP BY or an aggregate function"
            }
        }
    }

    fn diagnostic_message(&self, expr: &Expr) -> String {
        format!(
            "'{expr}' must appear in GROUP BY clause because it's not an aggregate expression"
        )
    }
}

/// Determines whether the given `exprs` form a valid projection on the input
/// `columns` (which must all be `Expr::Column`s): every column referenced by
/// `exprs` must appear in `columns`.
pub(crate) fn check_columns_satisfy_exprs(
    columns: &[Expr],
    exprs: &[Expr],
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    columns.iter().try_for_each(|c| match c {
        Expr::Column(_) => Ok(()),
        _ => internal_err!("Expr::Column are required"),
    })?;
    let column_exprs = find_column_exprs(exprs);
    for e in &column_exprs {
        match e {
            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
                for exprs in lists_of_exprs {
                    for e in exprs {
                        check_column_satisfies_expr(columns, e, purpose)?;
                    }
                }
            }
            _ => check_column_satisfies_expr(columns, e, purpose)?,
        }
    }
    Ok(())
}

fn check_column_satisfies_expr(
    columns: &[Expr],
    expr: &Expr,
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    if !columns.contains(expr) {
        let diagnostic = Diagnostic::new_error(
            purpose.diagnostic_message(expr),
            expr.spans().and_then(|spans| spans.first()),
        )
        .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);

        return plan_err!(
            "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
            purpose.message_prefix(),
            expr,
            expr_vec_fmt!(columns);
            diagnostic=diagnostic
        );
    }
    Ok(())
}

/// Returns mapping of each alias (`String`) to the expression (`Expr`) it is
/// aliasing.
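///
/// A rough sketch (expressions are illustrative):
///
/// ```text
/// extract_aliases([(c1 + 1).alias("total"), c2])
///   => { "total" => c1 + 1 }    // plain `c2` carries no alias and is skipped
/// ```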
pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
    exprs
        .iter()
        .filter_map(|expr| match expr {
            Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
            _ => None,
        })
        .collect::<HashMap<String, Expr>>()
}

/// Given an expression that is a literal integer encoding a 1-based position, look up the
/// corresponding expression in the `select_exprs` list; returns a planning error if the
/// position is out of bounds.
/// If the input expression is not an integer literal, it is returned as-is.
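///
/// Sketch of the behavior (mirroring the unit test at the bottom of this file):
///
/// ```text
/// select_exprs = [c1, c2, count(1)]
/// resolve_positions_to_exprs(lit(1i64), select_exprs)   => c1
/// resolve_positions_to_exprs(lit(5i64), select_exprs)   => plan error, valid positions are 1 to 3
/// resolve_positions_to_exprs(col("fake"), select_exprs) => col("fake")   // returned as-is
/// ```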
pub(crate) fn resolve_positions_to_exprs(
    expr: Expr,
    select_exprs: &[Expr],
) -> Result<Expr> {
    match expr {
        // sql_expr_to_logical_expr maps number to i64
        // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887
        Expr::Literal(ScalarValue::Int64(Some(position)), _)
            if position > 0_i64 && position <= select_exprs.len() as i64 =>
        {
            let index = (position - 1) as usize;
            let select_expr = &select_exprs[index];
            Ok(match select_expr {
                Expr::Alias(Alias { expr, .. }) => *expr.clone(),
                _ => select_expr.clone(),
            })
        }
        Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!(
            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
            position,
            select_exprs.len()
        ),
        _ => Ok(expr),
    }
}

/// Rebuilds an `Expr` with columns that refer to aliases replaced by the
/// alias' underlying `Expr`.
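///
/// Illustrative sketch (alias and column names are hypothetical):
///
/// ```text
/// aliases = { "total" => c1 + c2 }
/// resolve_aliases_to_exprs(col("total") > lit(10), aliases)
///   => (c1 + c2) > lit(10)
/// ```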
pub(crate) fn resolve_aliases_to_exprs(
    expr: Expr,
    aliases: &HashMap<String, Expr>,
) -> Result<Expr> {
    expr.transform_up(|nested_expr| match nested_expr {
        Expr::Column(c) if c.relation.is_none() => {
            if let Some(aliased_expr) = aliases.get(&c.name) {
                Ok(Transformed::yes(aliased_expr.clone()))
            } else {
                Ok(Transformed::no(Expr::Column(c)))
            }
        }
        _ => Ok(Transformed::no(nested_expr)),
    })
    .data()
}

/// Given a slice of window expressions sharing the same sort key, find their common partition
/// keys.
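///
/// A rough sketch (window exprs shown only by their PARTITION BY clauses):
///
/// ```text
/// [ OVER (PARTITION BY a, b), OVER (PARTITION BY a, b, c) ]
///   => [a, b]    // the shortest partition-by list is returned as the common key set
/// ```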
pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
    let all_partition_keys = window_exprs
        .iter()
        .map(|expr| match expr {
            Expr::WindowFunction(window_fun) => {
                let WindowFunction {
                    params: WindowFunctionParams { partition_by, .. },
                    ..
                } = window_fun.as_ref();
                Ok(partition_by)
            }
            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
                Expr::WindowFunction(window_fun) => {
                    let WindowFunction {
                        params: WindowFunctionParams { partition_by, .. },
                        ..
                    } = window_fun.as_ref();
                    Ok(partition_by)
                }
                expr => exec_err!("Impossibly got non-window expr {expr:?}"),
            },
            expr => exec_err!("Impossibly got non-window expr {expr:?}"),
        })
        .collect::<Result<Vec<_>>>()?;
    let result = all_partition_keys
        .iter()
        .min_by_key(|s| s.len())
        .ok_or_else(|| exec_datafusion_err!("No window expressions found"))?;
    Ok(result)
}

/// Returns a validated `DataType` for the specified precision and
/// scale
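///
/// A sketch of the mapping implemented below (PostgreSQL-like defaults):
///
/// ```text
/// make_decimal_type(Some(10), Some(2))  => Decimal128(10, 2)
/// make_decimal_type(Some(10), None)     => Decimal128(10, 0)
/// make_decimal_type(Some(50), Some(2))  => Decimal256(50, 2)   // precision > 38
/// make_decimal_type(None, Some(2))      => plan error
/// make_decimal_type(None, None)         => Decimal128(38, 10)  // default precision/scale
/// ```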
pub(crate) fn make_decimal_type(
    precision: Option<u64>,
    scale: Option<u64>,
) -> Result<DataType> {
    // postgres like behavior
    let (precision, scale) = match (precision, scale) {
        (Some(p), Some(s)) => (p as u8, s as i8),
        (Some(p), None) => (p as u8, 0),
        (None, Some(_)) => {
            return plan_err!("Cannot specify only scale for decimal data type");
        }
        (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
    };

    if precision == 0
        || precision > DECIMAL256_MAX_PRECISION
        || scale.unsigned_abs() > precision
    {
        plan_err!(
            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
        )
    } else if precision > DECIMAL128_MAX_PRECISION
        && precision <= DECIMAL256_MAX_PRECISION
    {
        Ok(DataType::Decimal256(precision, scale))
    } else {
        Ok(DataType::Decimal128(precision, scale))
    }
}

/// Normalize an owned identifier to a lowercase string, unless the identifier is quoted.
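///
/// For example (sketch; the `Ident` constructors are from `sqlparser`):
///
/// ```text
/// normalize_ident(Ident::new("SomeTable"))              => "sometable"
/// normalize_ident(Ident::with_quote('"', "SomeTable"))  => "SomeTable"
/// ```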
pub(crate) fn normalize_ident(id: Ident) -> String {
    match id.quote_style {
        Some(_) => id.value,
        None => id.value.to_ascii_lowercase(),
    }
}

pub(crate) fn value_to_string(value: &Value) -> Option<String> {
    match value {
        Value::SingleQuotedString(s) => Some(s.to_string()),
        Value::DollarQuotedString(s) => Some(s.to_string()),
        Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
        Value::UnicodeStringLiteral(s) => Some(s.to_string()),
        Value::EscapedStringLiteral(s) => Some(s.to_string()),
        Value::DoubleQuotedString(_)
        | Value::NationalStringLiteral(_)
        | Value::SingleQuotedByteStringLiteral(_)
        | Value::DoubleQuotedByteStringLiteral(_)
        | Value::TripleSingleQuotedString(_)
        | Value::TripleDoubleQuotedString(_)
        | Value::TripleSingleQuotedByteStringLiteral(_)
        | Value::TripleDoubleQuotedByteStringLiteral(_)
        | Value::SingleQuotedRawStringLiteral(_)
        | Value::DoubleQuotedRawStringLiteral(_)
        | Value::TripleSingleQuotedRawStringLiteral(_)
        | Value::TripleDoubleQuotedRawStringLiteral(_)
        | Value::HexStringLiteral(_)
        | Value::Null
        | Value::Placeholder(_) => None,
    }
}

/// Applies [`rewrite_recursive_unnest_bottom_up`] to each expression in `original_exprs`
/// and concatenates the rewritten expressions.
pub(crate) fn rewrite_recursive_unnests_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_exprs: &[Expr],
) -> Result<Vec<Expr>> {
    Ok(original_exprs
        .iter()
        .map(|expr| {
            rewrite_recursive_unnest_bottom_up(
                input,
                unnest_placeholder_columns,
                inner_projection_exprs,
                expr,
            )
        })
        .collect::<Result<Vec<_>>>()?
        .into_iter()
        .flatten()
        .collect::<Vec<_>>())
}

pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";

/// Rewriter that transforms `unnest` expressions bottom-up. It is only useful when driven
/// by a combined down/up traversal (`f_down` followed by `f_up`); a sketch of the full
/// transformation is shown below.
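///
/// Sketch based on the unit tests at the bottom of this file, for the select expression
/// `unnest(unnest(3d_col))`:
///
/// ```text
/// outer projection : __unnest_placeholder(3d_col,depth=2) AS UNNEST(UNNEST(3d_col))
/// unnest           : __unnest_placeholder(3d_col) => [__unnest_placeholder(3d_col,depth=2)|depth=2]
/// inner projection : 3d_col AS __unnest_placeholder(3d_col)
/// ```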
struct RecursiveUnnestRewriter<'a> {
    input_schema: &'a DFSchemaRef,
    root_expr: &'a Expr,
    // Used to detect whether a child expr is part of an unnest operation or not
    top_most_unnest: Option<Unnest>,
    consecutive_unnest: Vec<Option<Unnest>>,
    inner_projection_exprs: &'a mut Vec<Expr>,
    columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    transformed_root_exprs: Option<Vec<Expr>>,
}

impl RecursiveUnnestRewriter<'_> {
    /// The rewriter records the traversal history of exprs as a list such as
    /// \[None, **Unnest(exprA)**, **Unnest(exprB)**, None, None\];
    /// for that history this function returns \[**Unnest(exprA)**, **Unnest(exprB)**\].
    ///
    /// The first item is the innermost expr.
    fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
        self.consecutive_unnest
            .iter()
            .rev()
            .skip_while(|item| item.is_none())
            .take_while(|item| item.is_some())
            .to_owned()
            .cloned()
            .map(|item| item.unwrap())
            .collect()
    }

    fn transform(
        &mut self,
        level: usize,
        alias_name: String,
        expr_in_unnest: &Expr,
        struct_allowed: bool,
    ) -> Result<Vec<Expr>> {
        let inner_expr_name = expr_in_unnest.schema_name().to_string();

        // Full context: we plan the execution as InnerProjection -> Unnest -> OuterProjection.
        // During unnest execution, each column inside the inner projection is transformed into
        // new columns, so we need to keep track of these placeholder column names.
        let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})");
        let post_unnest_name =
            format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})");
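        // For example (see the tests below), `unnest(unnest(3d_col))` yields the placeholder
        // "__unnest_placeholder(3d_col)" and the post-unnest name "__unnest_placeholder(3d_col,depth=2)".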
        // The unnest transformation must keep the original column name as-is, so that
        // GROUP BY and ORDER BY can still refer to it.
        let placeholder_column = Column::from_name(placeholder_name.clone());
        let field = expr_in_unnest.to_field(self.input_schema)?.1;
        let data_type = field.data_type();

        match data_type {
            DataType::Struct(inner_fields) => {
                assert_or_internal_err!(
                    struct_allowed,
                    "unnest on struct can only be applied at the root level of select expression"
                );
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );
                self.columns_unnestings
                    .insert(Column::from_name(placeholder_name.clone()), None);
                Ok(get_struct_unnested_columns(&placeholder_name, inner_fields)
                    .into_iter()
                    .map(Expr::Column)
                    .collect())
            }
            DataType::List(_)
            | DataType::FixedSizeList(_, _)
            | DataType::LargeList(_) => {
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );

                let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
                let list_unnesting = self
                    .columns_unnestings
                    .entry(placeholder_column)
                    .or_insert(Some(vec![]));
                let unnesting = ColumnUnnestList {
                    output_column: Column::from_name(post_unnest_name),
                    depth: level,
                };
                let list_unnestings = list_unnesting.as_mut().unwrap();
                if !list_unnestings.contains(&unnesting) {
                    list_unnestings.push(unnesting);
                }
                Ok(vec![post_unnest_expr])
            }
            _ => {
                internal_err!("unnest on non-list or struct type is not supported")
            }
        }
    }
}

impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
    type Node = Expr;

    /// This downward traversal needs to keep track of:
    /// - Whether or not some unnest expr has been visited from the top until the current node
    /// - If some unnest expr has been visited, a stack of such information is maintained; this
    ///   is used to detect whether a recursive unnest expr exists (e.g. **unnest(unnest(unnest(3d column)))**)
    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref unnest_expr) = expr {
            let field = unnest_expr.expr.to_field(self.input_schema)?.1;
            let data_type = field.data_type();
            self.consecutive_unnest.push(Some(unnest_expr.clone()));
            // If the expr inside the unnest is a struct, do not consider
            // the next unnest (if any) as a consecutive unnest,
            // meaning unnest(unnest(struct_arr_col)) can't
            // be interpreted as unnest(struct_arr_col, depth:=2)
            // but has to be split into multiple unnest logical plans instead, i.e.:
            // - unnest(struct_col)
            //      unnest(struct_arr_col) as struct_col

            if let DataType::Struct(_) = data_type {
                self.consecutive_unnest.push(None);
            }
            if self.top_most_unnest.is_none() {
                self.top_most_unnest = Some(unnest_expr.clone());
            }

            Ok(Transformed::no(expr))
        } else {
            self.consecutive_unnest.push(None);
            Ok(Transformed::no(expr))
        }
    }

    /// The rewriting only happens when the traversal has reached the top-most unnest expr
    /// within a sequence of consecutive unnest exprs.
    ///
    /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))**
    /// ```text
    ///                         ┌──────────────────┐
    ///                         │    binaryexpr    │
    ///                         │                  │
    ///                         └──────────────────┘
    ///                f_down  / /            │ │
    ///                       / / f_up        │ │
    ///                      / /        f_down│ │f_up
    ///                  unnest               │ │
    ///                                       │ │
    ///       f_down  / / f_up(rewriting)     │ │
    ///              / /
    ///             / /                      unnest
    ///         unnest
    ///                           f_down  / / f_up(rewriting)
    /// f_down / /f_up                   / /
    ///       / /                       / /
    ///      / /                    unnest
    ///   column1
    ///                     f_down / /f_up
    ///                           / /
    ///                          / /
    ///                       column2
    /// ```
    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref traversing_unnest) = expr {
            if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
                self.top_most_unnest = None;
            }
            // Get the latest sequence of consecutive unnest exprs from `consecutive_unnest`
            // and check if the current upward traversal is returning to the root of that sequence.
            // For example, given an expr `unnest(unnest(col))` the traversal happens like:
            // down(unnest) -> down(unnest) -> down(col) -> up(col) -> up(unnest) -> up(unnest)
            // and the result of such a traversal is unnest(col, depth:=2)
            let unnest_stack = self.get_latest_consecutive_unnest();

            // This traversal has reached the top-most unnest again,
            // e.g. Unnest(top) -> Unnest(2nd) -> Column(bottom)
            //      -> Unnest(2nd) -> Unnest(top) (i.e. here).
            // Thus Unnest(Unnest(some_col)) is rewritten into Unnest(some_col, depth:=2)
            if traversing_unnest == unnest_stack.last().unwrap() {
                let most_inner = unnest_stack.first().unwrap();
                let inner_expr = most_inner.expr.as_ref();
                // unnest(unnest(struct_arr_col)) is not allowed to be done recursively;
                // it needs to be split into multiple unnest logical plans:
                // unnest(struct_arr)
                //  unnest(struct_arr_col) as struct_arr
                // instead of unnest(struct_arr_col, depth = 2)

                let unnest_recursion = unnest_stack.len();
                let struct_allowed = (&expr == self.root_expr) && unnest_recursion == 1;

                let mut transformed_exprs = self.transform(
                    unnest_recursion,
                    expr.schema_name().to_string(),
                    inner_expr,
                    struct_allowed,
                )?;
                if struct_allowed {
                    self.transformed_root_exprs = Some(transformed_exprs.clone());
                }
                return Ok(Transformed::new(
                    transformed_exprs.swap_remove(0),
                    true,
                    TreeNodeRecursion::Continue,
                ));
            }
        } else {
            self.consecutive_unnest.push(None);
        }

        // For column exprs that are not descendants of any unnest node,
        // retain their projection.
        // E.g. given the expr tree unnest(col_a) + col_b, we have to retain the projection of col_b;
        // this condition can be checked by maintaining an Option<top-most unnest>
        if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
            push_projection_dedupl(self.inner_projection_exprs, expr.clone());
        }

        Ok(Transformed::no(expr))
    }
}

/// Pushes `expr` into `projection` unless an expression with the same schema name is
/// already present.
fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
    let schema_name = expr.schema_name().to_string();
    if !projection
        .iter()
        .any(|e| e.schema_name().to_string() == schema_name)
    {
        projection.push(expr);
    }
}

/// The overall goal is to rewrite `unnest()` into InnerProjection -> Unnest -> OuterProjection.
/// Given an expression which contains an unnest expr as one of its children,
/// the transformation depends on the unnest type:
/// - For a list column: unnest(col) with type list -> unnest(col) with type list::item
/// - For a struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2
///
/// The transformed exprs will be used in the outer projection.
/// If there are multiple unnest expressions along the path from root to bottom, the
/// transformation is done only for the bottom-most expression.
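///
/// A sketch for `unnest(array_col) + 1`, mirroring the unit tests at the bottom of this file:
///
/// ```text
/// outer expr       : __unnest_placeholder(array_col,depth=1) AS UNNEST(array_col) + 1
/// unnest           : __unnest_placeholder(array_col) => [__unnest_placeholder(array_col,depth=1)|depth=1]
/// inner projection : array_col AS __unnest_placeholder(array_col)
/// ```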
pub(crate) fn rewrite_recursive_unnest_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_expr: &Expr,
) -> Result<Vec<Expr>> {
    let mut rewriter = RecursiveUnnestRewriter {
        input_schema: input.schema(),
        root_expr: original_expr,
        top_most_unnest: None,
        consecutive_unnest: vec![],
        inner_projection_exprs,
        columns_unnestings: unnest_placeholder_columns,
        transformed_root_exprs: None,
    };

    // This transformation is only done for list unnest;
    // struct unnest is done at the root level, and at a later stage,
    // because the TreeNode API only supports transforming into a single Expr, while
    // unnesting a struct is transformed into multiple Exprs.
    // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102
    //
    // The transformation looks like:
    // - unnest(array_col) will be transformed into Column("__unnest_placeholder(array_col)")
    // - unnest(array_col) + 1 will be transformed into Column("__unnest_placeholder(array_col)") + 1
    let Transformed {
        data: transformed_expr,
        transformed,
        tnr: _,
    } = original_expr.clone().rewrite(&mut rewriter)?;

    if !transformed {
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        if matches!(&transformed_expr, Expr::Column(_))
            || matches!(&transformed_expr, Expr::Wildcard { .. })
        {
            push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
            Ok(vec![transformed_expr])
        } else {
            // We need to evaluate the expr in the inner projection;
            // the outer projection just selects it by name.
            let column_name = transformed_expr.schema_name().to_string();
            push_projection_dedupl(inner_projection_exprs, transformed_expr);
            Ok(vec![Expr::Column(Column::from_name(column_name))])
        }
    } else {
        if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
            return Ok(transformed_root_exprs);
        }
        Ok(vec![transformed_expr])
    }
}

#[cfg(test)]
mod tests {
    use std::{ops::Add, sync::Arc};

    use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
    use datafusion_common::{Column, DFSchema, Result};
    use datafusion_expr::{
        ColumnUnnestList, EmptyRelation, LogicalPlan, col, lit, unnest,
    };
    use datafusion_functions::core::expr_ext::FieldAccessor;
    use datafusion_functions_aggregate::expr_fn::count;

    use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
    use indexmap::IndexMap;

    fn column_unnests_eq(
        l: Vec<&str>,
        r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    ) {
        let r_formatted: Vec<String> = r
            .iter()
            .map(|i| match i.1 {
                None => format!("{}", i.0),
                Some(vec) => format!(
                    "{}=>[{}]",
                    i.0,
                    vec.iter()
                        .map(|i| format!("{i}"))
                        .collect::<Vec<String>>()
                        .join(", ")
                ),
            })
            .collect();
        let l_formatted: Vec<String> = l.iter().map(|i| (*i).to_string()).collect();
        assert_eq!(l_formatted, r_formatted);
    }

    #[test]
    fn test_transform_bottom_unnest_recursive() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "3d_col",
                ArrowDataType::List(Arc::new(Field::new(
                    "2d_col",
                    ArrowDataType::List(Arc::new(Field::new(
                        "elements",
                        ArrowDataType::Int64,
                        true,
                    ))),
                    true,
                ))),
                true,
            ),
            Field::new("i64_col", ArrowDataType::Int64, true),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(unnest(3d_col)) + unnest(unnest(3d_col))
        let original_expr = unnest(unnest(col("3d_col")))
            .add(unnest(unnest(col("3d_col"))))
            .add(col("i64_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // Only the bottom most unnest exprs are transformed
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(3d_col,depth=2)")
                    .alias("UNNEST(UNNEST(3d_col))")
                    .add(
                        col("__unnest_placeholder(3d_col,depth=2)")
                            .alias("UNNEST(UNNEST(3d_col))")
                    )
                    .add(col("i64_col"))
            ]
        );
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
            ],
            &unnest_placeholder_columns,
        );

        // Still reference 3d_col in the original schema but with an alias,
        // to avoid colliding with a projection on the column itself, if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        // unnest(3d_col) as 2d_col
        let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr_2,
        )?;

        assert_eq!(
            transformed_exprs,
            vec![
                (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
                    .alias("2d_col")
            ]
        );
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Still reference 3d_col in the original schema but with an alias,
        // to avoid colliding with a projection on the column itself, if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        Ok(())
    }

    #[test]
    fn test_transform_bottom_unnest() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "struct_col",
                ArrowDataType::Struct(Fields::from(vec![
                    Field::new("field1", ArrowDataType::Int32, false),
                    Field::new("field2", ArrowDataType::Int32, false),
                ])),
                false,
            ),
            Field::new(
                "array_col",
                ArrowDataType::List(Arc::new(Field::new_list_field(
                    ArrowDataType::Int64,
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(struct_col)
        let original_expr = unnest(col("struct_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(struct_col).field1"),
                col("__unnest_placeholder(struct_col).field2"),
            ]
        );
        column_unnests_eq(
            vec!["__unnest_placeholder(struct_col)"],
            &unnest_placeholder_columns,
        );
        // Still reference struct_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
        );

        // unnest(array_col) + 1
        let original_expr = unnest(col("array_col")).add(lit(1i64));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_col)",
                "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Only transform the unnest children
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(array_col,depth=1)")
                    .alias("UNNEST(array_col)")
                    .add(lit(1i64))
            ]
        );

        // Keep appending to the current vector
        // Still reference array_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("struct_col").alias("__unnest_placeholder(struct_col)"),
                col("array_col").alias("__unnest_placeholder(array_col)")
            ]
        );

        Ok(())
    }

    // Unnest -> field access -> unnest
    #[test]
    fn test_transform_non_consecutive_unnests() -> Result<()> {
        // List of struct
        // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}]
        let schema = Schema::new(vec![
            Field::new(
                "struct_list",
                ArrowDataType::List(Arc::new(Field::new(
                    "element",
                    ArrowDataType::Struct(Fields::from(vec![
                        Field::new(
                            // list of i64
                            "subfield1",
                            ArrowDataType::List(Arc::new(Field::new(
                                "i64_element",
                                ArrowDataType::Int64,
                                true,
                            ))),
                            true,
                        ),
                        Field::new(
                            // list of utf8
                            "subfield2",
                            ArrowDataType::List(Arc::new(Field::new(
                                "utf8_element",
                                ArrowDataType::Utf8,
                                true,
                            ))),
                            true,
                        ),
                    ])),
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // An expr with multiple unnest
        let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr1,
        )?;
        // Only the innermost/bottom-most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield1")
            )]
        );

        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        // Continue rewriting another expr in the select list
        let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr2,
        )?;
        // Only the innermost/bottom-most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield2")
            )]
        );

        // The unnest placeholder columns remain the same
        // because expr1 and expr2 derive from the same unnest result
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        Ok(())
    }

    #[test]
    fn test_resolve_positions_to_exprs() -> Result<()> {
        let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];

        // Assert 1 resolved as first column in select list
        let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
        assert_eq!(resolved, col("c1"));

        // Assert error if index out of select clause bounds
        let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
        )));

        let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
        )));

        // Assert expression returned as-is
        let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
        assert_eq!(resolved, lit("text"));

        let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
        assert_eq!(resolved, col("fake"));

        Ok(())
    }
}