Skip to main content

datafusion_sql/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! SQL Utility Functions
19
20use std::vec;
21
22use arrow::datatypes::{
23    DECIMAL_DEFAULT_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType,
24};
25use datafusion_common::tree_node::{
26    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
27};
28use datafusion_common::{
29    Column, DFSchemaRef, Diagnostic, HashMap, Result, ScalarValue,
30    assert_or_internal_err, exec_datafusion_err, exec_err, internal_err, plan_err,
31};
32use datafusion_expr::builder::get_struct_unnested_columns;
33use datafusion_expr::expr::{
34    Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
35};
36use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
37use datafusion_expr::{
38    ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, col, expr_vec_fmt,
39};
40
41use indexmap::IndexMap;
42use sqlparser::ast::{Ident, Value};
43
44/// Make a best-effort attempt at resolving all columns in the expression tree
45pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
46    expr.clone()
47        .transform_up(|nested_expr| {
48            match nested_expr {
49                Expr::Column(col) => {
50                    let (qualifier, field) =
51                        plan.schema().qualified_field_from_column(&col)?;
52                    Ok(Transformed::yes(Expr::Column(Column::from((
53                        qualifier, field,
54                    )))))
55                }
56                _ => {
57                    // keep recursing
58                    Ok(Transformed::no(nested_expr))
59                }
60            }
61        })
62        .data()
63}
64
65/// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
66///
67/// For example, the expression `a + b < 1` would require, as input, the 2
68/// individual columns, `a` and `b`. But, if the base expressions already
69/// contain the `a + b` result, then that may be used in lieu of the `a` and
70/// `b` columns.
71///
72/// This is useful in the context of a query like:
73///
74/// SELECT a + b < 1 ... GROUP BY a + b
75///
76/// where post-aggregation, `a + b` need not be a projection against the
77/// individual columns `a` and `b`, but rather it is a projection against the
78/// `a + b` found in the GROUP BY.
79pub(crate) fn rebase_expr(
80    expr: &Expr,
81    base_exprs: &[Expr],
82    plan: &LogicalPlan,
83) -> Result<Expr> {
84    expr.clone()
85        .transform_down(|nested_expr| {
86            if base_exprs.contains(&nested_expr) {
87                Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
88            } else {
89                Ok(Transformed::no(nested_expr))
90            }
91        })
92        .data()
93}
94
/// The query clause whose columns are being checked against the GROUP BY
/// and aggregate expressions.
///
/// The variant selects the wording of the resulting planning error.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
    /// Columns referenced in the SELECT list.
    Projection,
    /// Columns referenced in the HAVING clause.
    Having,
    /// Columns referenced in the QUALIFY clause.
    Qualify,
    /// Columns referenced in the ORDER BY clause.
    OrderBy,
}
102
103#[derive(Debug, Clone, Copy, PartialEq, Eq)]
104pub(crate) enum CheckColumnsSatisfyExprsPurpose {
105    Aggregate(CheckColumnsMustReferenceAggregatePurpose),
106}
107
108impl CheckColumnsSatisfyExprsPurpose {
109    fn message_prefix(&self) -> &'static str {
110        match self {
111            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
112                "Column in SELECT must be in GROUP BY or an aggregate function"
113            }
114            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
115                "Column in HAVING must be in GROUP BY or an aggregate function"
116            }
117            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
118                "Column in QUALIFY must be in GROUP BY or an aggregate function"
119            }
120            Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::OrderBy) => {
121                "Column in ORDER BY must be in GROUP BY or an aggregate function"
122            }
123        }
124    }
125
126    fn diagnostic_message(&self, expr: &Expr) -> String {
127        format!(
128            "'{expr}' must appear in GROUP BY clause because it's not an aggregate expression"
129        )
130    }
131}
132
133/// Determines if the set of `Expr`'s are a valid projection on the input
134/// `Expr::Column`'s.
135pub(crate) fn check_columns_satisfy_exprs(
136    columns: &[Expr],
137    exprs: &[Expr],
138    purpose: CheckColumnsSatisfyExprsPurpose,
139) -> Result<()> {
140    columns.iter().try_for_each(|c| match c {
141        Expr::Column(_) => Ok(()),
142        _ => internal_err!("Expr::Column are required"),
143    })?;
144    let column_exprs = find_column_exprs(exprs);
145    for e in &column_exprs {
146        match e {
147            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
148                for e in exprs {
149                    check_column_satisfies_expr(columns, e, purpose)?;
150                }
151            }
152            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
153                for e in exprs {
154                    check_column_satisfies_expr(columns, e, purpose)?;
155                }
156            }
157            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
158                for exprs in lists_of_exprs {
159                    for e in exprs {
160                        check_column_satisfies_expr(columns, e, purpose)?;
161                    }
162                }
163            }
164            _ => check_column_satisfies_expr(columns, e, purpose)?,
165        }
166    }
167    Ok(())
168}
169
170fn check_column_satisfies_expr(
171    columns: &[Expr],
172    expr: &Expr,
173    purpose: CheckColumnsSatisfyExprsPurpose,
174) -> Result<()> {
175    if !columns.contains(expr) {
176        let diagnostic = Diagnostic::new_error(
177            purpose.diagnostic_message(expr),
178            expr.spans().and_then(|spans| spans.first()),
179        )
180        .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);
181
182        return plan_err!(
183            "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
184            purpose.message_prefix(),
185            expr,
186            expr_vec_fmt!(columns);
187            diagnostic=diagnostic
188        );
189    }
190    Ok(())
191}
192
193/// Returns mapping of each alias (`String`) to the expression (`Expr`) it is
194/// aliasing.
195pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
196    exprs
197        .iter()
198        .filter_map(|expr| match expr {
199            Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
200            _ => None,
201        })
202        .collect::<HashMap<String, Expr>>()
203}
204
205/// Given an expression that's literal int encoding position, lookup the corresponding expression
206/// in the select_exprs list, if the index is within the bounds and it is indeed a position literal,
207/// otherwise, returns planning error.
208/// If input expression is not an int literal, returns expression as-is.
209pub(crate) fn resolve_positions_to_exprs(
210    expr: Expr,
211    select_exprs: &[Expr],
212) -> Result<Expr> {
213    match expr {
214        // sql_expr_to_logical_expr maps number to i64
215        // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887
216        Expr::Literal(ScalarValue::Int64(Some(position)), _)
217            if position > 0_i64 && position <= select_exprs.len() as i64 =>
218        {
219            let index = (position - 1) as usize;
220            let select_expr = &select_exprs[index];
221            Ok(match select_expr {
222                Expr::Alias(Alias { expr, .. }) => *expr.clone(),
223                _ => select_expr.clone(),
224            })
225        }
226        Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!(
227            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
228            position,
229            select_exprs.len()
230        ),
231        _ => Ok(expr),
232    }
233}
234
235/// Rebuilds an `Expr` with columns that refer to aliases replaced by the
236/// alias' underlying `Expr`.
237pub(crate) fn resolve_aliases_to_exprs(
238    expr: Expr,
239    aliases: &HashMap<String, Expr>,
240) -> Result<Expr> {
241    expr.transform_up(|nested_expr| match nested_expr {
242        Expr::Column(c) if c.relation.is_none() => {
243            if let Some(aliased_expr) = aliases.get(&c.name) {
244                Ok(Transformed::yes(aliased_expr.clone()))
245            } else {
246                Ok(Transformed::no(Expr::Column(c)))
247            }
248        }
249        _ => Ok(Transformed::no(nested_expr)),
250    })
251    .data()
252}
253
254/// Given a slice of window expressions sharing the same sort key, find their common partition
255/// keys.
256pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
257    let all_partition_keys = window_exprs
258        .iter()
259        .map(|expr| match expr {
260            Expr::WindowFunction(window_fun) => {
261                let WindowFunction {
262                    params: WindowFunctionParams { partition_by, .. },
263                    ..
264                } = window_fun.as_ref();
265                Ok(partition_by)
266            }
267            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
268                Expr::WindowFunction(window_fun) => {
269                    let WindowFunction {
270                        params: WindowFunctionParams { partition_by, .. },
271                        ..
272                    } = window_fun.as_ref();
273                    Ok(partition_by)
274                }
275                expr => exec_err!("Impossibly got non-window expr {expr:?}"),
276            },
277            expr => exec_err!("Impossibly got non-window expr {expr:?}"),
278        })
279        .collect::<Result<Vec<_>>>()?;
280    let result = all_partition_keys
281        .iter()
282        .min_by_key(|s| s.len())
283        .ok_or_else(|| exec_datafusion_err!("No window expressions found"))?;
284    Ok(result)
285}
286
287/// Returns a validated `DataType` for the specified precision and
288/// scale
289pub(crate) fn make_decimal_type(
290    precision: Option<u64>,
291    scale: Option<u64>,
292) -> Result<DataType> {
293    // postgres like behavior
294    let (precision, scale) = match (precision, scale) {
295        (Some(p), Some(s)) => (p as u8, s as i8),
296        (Some(p), None) => (p as u8, 0),
297        (None, Some(_)) => {
298            return plan_err!("Cannot specify only scale for decimal data type");
299        }
300        (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
301    };
302
303    if precision == 0
304        || precision > DECIMAL256_MAX_PRECISION
305        || scale.unsigned_abs() > precision
306    {
307        plan_err!(
308            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
309        )
310    } else if precision > DECIMAL128_MAX_PRECISION
311        && precision <= DECIMAL256_MAX_PRECISION
312    {
313        Ok(DataType::Decimal256(precision, scale))
314    } else {
315        Ok(DataType::Decimal128(precision, scale))
316    }
317}
318
319/// Normalize an owned identifier to a lowercase string, unless the identifier is quoted.
320pub(crate) fn normalize_ident(id: Ident) -> String {
321    match id.quote_style {
322        Some(_) => id.value,
323        None => id.value.to_ascii_lowercase(),
324    }
325}
326
327pub(crate) fn value_to_string(value: &Value) -> Option<String> {
328    match value {
329        Value::SingleQuotedString(s) => Some(s.to_string()),
330        Value::DollarQuotedString(s) => Some(s.to_string()),
331        Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
332        Value::UnicodeStringLiteral(s) => Some(s.to_string()),
333        Value::EscapedStringLiteral(s) => Some(s.to_string()),
334        Value::QuoteDelimitedStringLiteral(s)
335        | Value::NationalQuoteDelimitedStringLiteral(s) => Some(s.value.to_string()),
336        Value::DoubleQuotedString(_)
337        | Value::NationalStringLiteral(_)
338        | Value::SingleQuotedByteStringLiteral(_)
339        | Value::DoubleQuotedByteStringLiteral(_)
340        | Value::TripleSingleQuotedString(_)
341        | Value::TripleDoubleQuotedString(_)
342        | Value::TripleSingleQuotedByteStringLiteral(_)
343        | Value::TripleDoubleQuotedByteStringLiteral(_)
344        | Value::SingleQuotedRawStringLiteral(_)
345        | Value::DoubleQuotedRawStringLiteral(_)
346        | Value::TripleSingleQuotedRawStringLiteral(_)
347        | Value::TripleDoubleQuotedRawStringLiteral(_)
348        | Value::HexStringLiteral(_)
349        | Value::Null
350        | Value::Placeholder(_) => None,
351    }
352}
353
354pub(crate) fn rewrite_recursive_unnests_bottom_up(
355    input: &LogicalPlan,
356    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
357    inner_projection_exprs: &mut Vec<Expr>,
358    original_exprs: &[Expr],
359) -> Result<Vec<Expr>> {
360    Ok(original_exprs
361        .iter()
362        .map(|expr| {
363            rewrite_recursive_unnest_bottom_up(
364                input,
365                unnest_placeholder_columns,
366                inner_projection_exprs,
367                expr,
368            )
369        })
370        .collect::<Result<Vec<_>>>()?
371        .into_iter()
372        .flatten()
373        .collect::<Vec<_>>())
374}
375
/// Prefix of the placeholder column names generated while rewriting `unnest`
/// expressions, e.g. `__unnest_placeholder(<expr>)` and
/// `__unnest_placeholder(<expr>,depth=<n>)`.
pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";
377
/// Expression rewriter that replaces `unnest(...)` calls with references to
/// placeholder columns, while recording the inner projection expressions and
/// the unnest operations needed to produce those columns.
///
/// This is only useful when used with a transform that goes down and then up
/// the tree: state is collected in `f_down` and the rewriting happens in
/// `f_up`. A worked traversal example is given on the `f_up` implementation.
struct RecursiveUnnestRewriter<'a> {
    /// Schema used to resolve the data type of the expressions inside `unnest`
    input_schema: &'a DFSchemaRef,
    /// The expression the rewrite started from
    root_expr: &'a Expr,
    // Useful to detect which child expr is a part of/ not a part of unnest operation
    top_most_unnest: Option<Unnest>,
    /// Traversal history: `Some(unnest)` for each visited unnest expr and
    /// `None` for anything else
    consecutive_unnest: Vec<Option<Unnest>>,
    /// Expressions to evaluate in the projection placed below the unnest
    inner_projection_exprs: &'a mut Vec<Expr>,
    /// For each placeholder column, the list unnestings to perform on it, or
    /// `None` when the placeholder is a struct unnest
    columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    /// Set when a struct unnest at the root was expanded into several exprs
    transformed_root_exprs: Option<Vec<Expr>>,
}
392impl RecursiveUnnestRewriter<'_> {
393    /// This struct stores the history of expr
394    /// during its tree-traversal with a notation of
395    /// \[None,**Unnest(exprA)**,**Unnest(exprB)**,None,None\]
396    /// then this function will returns \[**Unnest(exprA)**,**Unnest(exprB)**\]
397    ///
398    /// The first item will be the inner most expr
399    fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
400        self.consecutive_unnest
401            .iter()
402            .rev()
403            .skip_while(|item| item.is_none())
404            .take_while(|item| item.is_some())
405            .to_owned()
406            .cloned()
407            .map(|item| item.unwrap())
408            .collect()
409    }
410
411    /// Check if the current expression is at the root level for struct unnest purposes.
412    /// This is true if:
413    /// 1. The expression IS the root expression, OR
414    /// 2. The root expression is an Alias wrapping this expression
415    ///
416    /// This allows `unnest(struct_col) AS alias` to work, where the alias is simply
417    /// ignored for struct unnest (matching DuckDB behavior).
418    fn is_at_struct_allowed_root(&self, expr: &Expr) -> bool {
419        if expr == self.root_expr {
420            return true;
421        }
422        // Allow struct unnest when root is an alias wrapping the unnest
423        if let Expr::Alias(Alias { expr: inner, .. }) = self.root_expr {
424            return inner.as_ref() == expr;
425        }
426        false
427    }
428
429    fn transform(
430        &mut self,
431        level: usize,
432        alias_name: String,
433        expr_in_unnest: &Expr,
434        struct_allowed: bool,
435    ) -> Result<Vec<Expr>> {
436        let inner_expr_name = expr_in_unnest.schema_name().to_string();
437
438        // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection
439        // inside unnest execution, each column inside the inner projection
440        // will be transformed into new columns. Thus we need to keep track of these placeholding column names
441        let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})");
442        let post_unnest_name =
443            format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})");
444        // This is due to the fact that unnest transformation should keep the original
445        // column name as is, to comply with group by and order by
446        let placeholder_column = Column::from_name(placeholder_name.clone());
447        let field = expr_in_unnest.to_field(self.input_schema)?.1;
448        let data_type = field.data_type();
449
450        match data_type {
451            DataType::Struct(inner_fields) => {
452                assert_or_internal_err!(
453                    struct_allowed,
454                    "unnest on struct can only be applied at the root level of select expression"
455                );
456                push_projection_dedupl(
457                    self.inner_projection_exprs,
458                    expr_in_unnest.clone().alias(placeholder_name.clone()),
459                );
460                self.columns_unnestings
461                    .insert(Column::from_name(placeholder_name.clone()), None);
462                Ok(get_struct_unnested_columns(&placeholder_name, inner_fields)
463                    .into_iter()
464                    .map(Expr::Column)
465                    .collect())
466            }
467            DataType::List(_)
468            | DataType::FixedSizeList(_, _)
469            | DataType::LargeList(_)
470            | DataType::ListView(_)
471            | DataType::LargeListView(_) => {
472                push_projection_dedupl(
473                    self.inner_projection_exprs,
474                    expr_in_unnest.clone().alias(placeholder_name.clone()),
475                );
476
477                let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
478                let list_unnesting = self
479                    .columns_unnestings
480                    .entry(placeholder_column)
481                    .or_insert(Some(vec![]));
482                let unnesting = ColumnUnnestList {
483                    output_column: Column::from_name(post_unnest_name),
484                    depth: level,
485                };
486                let list_unnestings = list_unnesting.as_mut().unwrap();
487                if !list_unnestings.contains(&unnesting) {
488                    list_unnestings.push(unnesting);
489                }
490                Ok(vec![post_unnest_expr])
491            }
492            _ => {
493                internal_err!("unnest on non-list or struct type is not supported")
494            }
495        }
496    }
497}
498
499impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
500    type Node = Expr;
501
502    /// This downward traversal needs to keep track of:
503    /// - Whether or not some unnest expr has been visited from the top until the current node
504    /// - If some unnest expr has been visited, maintain a stack of such information, this
505    ///   is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))**
506    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
507        if let Expr::Unnest(ref unnest_expr) = expr {
508            let field = unnest_expr.expr.to_field(self.input_schema)?.1;
509            let data_type = field.data_type();
510            self.consecutive_unnest.push(Some(unnest_expr.clone()));
511            // if expr inside unnest is a struct, do not consider
512            // the next unnest as consecutive unnest (if any)
513            // meaning unnest(unnest(struct_arr_col)) can't
514            // be interpreted as unnest(struct_arr_col, depth:=2)
515            // but has to be split into multiple unnest logical plan instead
516            // a.k.a:
517            // - unnest(struct_col)
518            //      unnest(struct_arr_col) as struct_col
519
520            if let DataType::Struct(_) = data_type {
521                self.consecutive_unnest.push(None);
522            }
523            if self.top_most_unnest.is_none() {
524                self.top_most_unnest = Some(unnest_expr.clone());
525            }
526
527            Ok(Transformed::no(expr))
528        } else {
529            self.consecutive_unnest.push(None);
530            Ok(Transformed::no(expr))
531        }
532    }
533
534    /// The rewriting only happens when the traversal has reached the top-most unnest expr
535    /// within a sequence of consecutive unnest exprs node
536    ///
537    /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))**
538    /// ```text
539    ///                         ┌──────────────────┐
540    ///                         │    binaryexpr    │
541    ///                         │                  │
542    ///                         └──────────────────┘
543    ///                f_down  / /            │ │
544    ///                       / / f_up        │ │
545    ///                      / /        f_down│ │f_up
546    ///                  unnest               │ │
547    ///                                       │ │
548    ///       f_down  / / f_up(rewriting)     │ │
549    ///              / /
550    ///             / /                      unnest
551    ///         unnest
552    ///                           f_down  / / f_up(rewriting)
553    /// f_down / /f_up                   / /
554    ///       / /                       / /
555    ///      / /                    unnest
556    ///   column1
557    ///                     f_down / /f_up
558    ///                           / /
559    ///                          / /
560    ///                       column2
561    /// ```
562    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
563        if let Expr::Unnest(ref traversing_unnest) = expr {
564            if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
565                self.top_most_unnest = None;
566            }
567            // Find inside consecutive_unnest, the sequence of continuous unnest exprs
568
569            // Get the latest consecutive unnest exprs
570            // and check if current upward traversal is the returning to the root expr
571            // for example given a expr `unnest(unnest(col))` then the traversal happens like:
572            // down(unnest) -> down(unnest) -> down(col) -> up(col) -> up(unnest) -> up(unnest)
573            // the result of such traversal is unnest(col, depth:=2)
574            let unnest_stack = self.get_latest_consecutive_unnest();
575
576            // This traversal has reached the top most unnest again
577            // e.g Unnest(top) -> Unnest(2nd) -> Column(bottom)
578            // -> Unnest(2nd) -> Unnest(top) a.k.a here
579            // Thus
580            // Unnest(Unnest(some_col)) is rewritten into Unnest(some_col, depth:=2)
581            if traversing_unnest == unnest_stack.last().unwrap() {
582                let most_inner = unnest_stack.first().unwrap();
583                let inner_expr = most_inner.expr.as_ref();
584                // unnest(unnest(struct_arr_col)) is not allow to be done recursively
585                // it needs to be split into multiple unnest logical plan
586                // unnest(struct_arr)
587                //  unnest(struct_arr_col) as struct_arr
588                // instead of unnest(struct_arr_col, depth = 2)
589
590                let unnest_recursion = unnest_stack.len();
591                let struct_allowed =
592                    self.is_at_struct_allowed_root(&expr) && unnest_recursion == 1;
593
594                let mut transformed_exprs = self.transform(
595                    unnest_recursion,
596                    expr.schema_name().to_string(),
597                    inner_expr,
598                    struct_allowed,
599                )?;
600                // Only set transformed_root_exprs for struct unnest (which returns multiple expressions).
601                // For list unnest (single expression), we let the normal rewrite handle the alias.
602                if struct_allowed && transformed_exprs.len() > 1 {
603                    self.transformed_root_exprs = Some(transformed_exprs.clone());
604                }
605                return Ok(Transformed::new(
606                    transformed_exprs.swap_remove(0),
607                    true,
608                    TreeNodeRecursion::Continue,
609                ));
610            }
611        } else {
612            self.consecutive_unnest.push(None);
613        }
614
615        // For column exprs that are not descendants of any unnest node
616        // retain their projection
617        // e.g given expr tree unnest(col_a) + col_b, we have to retain projection of col_b
618        // this condition can be checked by maintaining an Option<top most unnest>
619        if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
620            push_projection_dedupl(self.inner_projection_exprs, expr.clone());
621        }
622
623        Ok(Transformed::no(expr))
624    }
625}
626
627fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
628    let schema_name = expr.schema_name().to_string();
629    if !projection
630        .iter()
631        .any(|e| e.schema_name().to_string() == schema_name)
632    {
633        projection.push(expr);
634    }
635}
636/// The context is we want to rewrite unnest() into InnerProjection->Unnest->OuterProjection
637/// Given an expression which contains unnest expr as one of its children,
638/// Try transform depends on unnest type
639/// - For list column: unnest(col) with type list -> unnest(col) with type list::item
640/// - For struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2
641///
642/// The transformed exprs will be used in the outer projection
643/// If along the path from root to bottom, there are multiple unnest expressions, the transformation
644/// is done only for the bottom expression
645pub(crate) fn rewrite_recursive_unnest_bottom_up(
646    input: &LogicalPlan,
647    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
648    inner_projection_exprs: &mut Vec<Expr>,
649    original_expr: &Expr,
650) -> Result<Vec<Expr>> {
651    let mut rewriter = RecursiveUnnestRewriter {
652        input_schema: input.schema(),
653        root_expr: original_expr,
654        top_most_unnest: None,
655        consecutive_unnest: vec![],
656        inner_projection_exprs,
657        columns_unnestings: unnest_placeholder_columns,
658        transformed_root_exprs: None,
659    };
660
661    // This transformation is only done for list unnest
662    // struct unnest is done at the root level, and at the later stage
663    // because the syntax of TreeNode only support transform into 1 Expr, while
664    // Unnest struct will be transformed into multiple Exprs
665    // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102
666    //
667    // The transformation looks like:
668    // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)")
669    // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1")
670    let Transformed {
671        data: transformed_expr,
672        transformed,
673        tnr: _,
674    } = original_expr.clone().rewrite(&mut rewriter)?;
675
676    if !transformed {
677        // TODO: remove the next line after `Expr::Wildcard` is removed
678        #[expect(deprecated)]
679        if matches!(&transformed_expr, Expr::Column(_))
680            || matches!(&transformed_expr, Expr::Wildcard { .. })
681        {
682            push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
683            Ok(vec![transformed_expr])
684        } else {
685            // We need to evaluate the expr in the inner projection,
686            // outer projection just select its name
687            let column_name = transformed_expr.schema_name().to_string();
688            push_projection_dedupl(inner_projection_exprs, transformed_expr);
689            Ok(vec![Expr::Column(Column::from_name(column_name))])
690        }
691    } else {
692        if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
693            return Ok(transformed_root_exprs);
694        }
695        Ok(vec![transformed_expr])
696    }
697}
698
699#[cfg(test)]
700mod tests {
701    use std::{ops::Add, sync::Arc};
702
703    use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
704    use datafusion_common::{Column, DFSchema, Result};
705    use datafusion_expr::{
706        ColumnUnnestList, EmptyRelation, LogicalPlan, col, lit, unnest,
707    };
708    use datafusion_functions::core::expr_ext::FieldAccessor;
709    use datafusion_functions_aggregate::expr_fn::count;
710
711    use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
712    use indexmap::IndexMap;
713
714    fn column_unnests_eq(
715        l: Vec<&str>,
716        r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
717    ) {
718        let r_formatted: Vec<String> = r
719            .iter()
720            .map(|i| match i.1 {
721                None => format!("{}", i.0),
722                Some(vec) => format!(
723                    "{}=>[{}]",
724                    i.0,
725                    vec.iter()
726                        .map(|i| format!("{i}"))
727                        .collect::<Vec<String>>()
728                        .join(", ")
729                ),
730            })
731            .collect();
732        let l_formatted: Vec<String> = l.iter().map(|i| (*i).to_string()).collect();
733        assert_eq!(l_formatted, r_formatted);
734    }
735
736    #[test]
737    fn test_transform_bottom_unnest_recursive() -> Result<()> {
738        let schema = Schema::new(vec![
739            Field::new(
740                "3d_col",
741                ArrowDataType::List(Arc::new(Field::new(
742                    "2d_col",
743                    ArrowDataType::List(Arc::new(Field::new(
744                        "elements",
745                        ArrowDataType::Int64,
746                        true,
747                    ))),
748                    true,
749                ))),
750                true,
751            ),
752            Field::new("i64_col", ArrowDataType::Int64, true),
753        ]);
754
755        let dfschema = DFSchema::try_from(schema)?;
756
757        let input = LogicalPlan::EmptyRelation(EmptyRelation {
758            produce_one_row: false,
759            schema: Arc::new(dfschema),
760        });
761
762        let mut unnest_placeholder_columns = IndexMap::new();
763        let mut inner_projection_exprs = vec![];
764
765        // unnest(unnest(3d_col)) + unnest(unnest(3d_col))
766        let original_expr = unnest(unnest(col("3d_col")))
767            .add(unnest(unnest(col("3d_col"))))
768            .add(col("i64_col"));
769        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
770            &input,
771            &mut unnest_placeholder_columns,
772            &mut inner_projection_exprs,
773            &original_expr,
774        )?;
775        // Only the bottom most unnest exprs are transformed
776        assert_eq!(
777            transformed_exprs,
778            vec![
779                col("__unnest_placeholder(3d_col,depth=2)")
780                    .alias("UNNEST(UNNEST(3d_col))")
781                    .add(
782                        col("__unnest_placeholder(3d_col,depth=2)")
783                            .alias("UNNEST(UNNEST(3d_col))")
784                    )
785                    .add(col("i64_col"))
786            ]
787        );
788        column_unnests_eq(
789            vec![
790                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
791            ],
792            &unnest_placeholder_columns,
793        );
794
795        // Still reference struct_col in original schema but with alias,
796        // to avoid colliding with the projection on the column itself if any
797        assert_eq!(
798            inner_projection_exprs,
799            vec![
800                col("3d_col").alias("__unnest_placeholder(3d_col)"),
801                col("i64_col")
802            ]
803        );
804
805        // unnest(3d_col) as 2d_col
806        let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
807        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
808            &input,
809            &mut unnest_placeholder_columns,
810            &mut inner_projection_exprs,
811            &original_expr_2,
812        )?;
813
814        assert_eq!(
815            transformed_exprs,
816            vec![
817                (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
818                    .alias("2d_col")
819            ]
820        );
821        column_unnests_eq(
822            vec![
823                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]",
824            ],
825            &unnest_placeholder_columns,
826        );
827        // Still reference struct_col in original schema but with alias,
828        // to avoid colliding with the projection on the column itself if any
829        assert_eq!(
830            inner_projection_exprs,
831            vec![
832                col("3d_col").alias("__unnest_placeholder(3d_col)"),
833                col("i64_col")
834            ]
835        );
836
837        Ok(())
838    }
839
840    #[test]
841    fn test_transform_bottom_unnest() -> Result<()> {
842        let schema = Schema::new(vec![
843            Field::new(
844                "struct_col",
845                ArrowDataType::Struct(Fields::from(vec![
846                    Field::new("field1", ArrowDataType::Int32, false),
847                    Field::new("field2", ArrowDataType::Int32, false),
848                ])),
849                false,
850            ),
851            Field::new(
852                "array_col",
853                ArrowDataType::List(Arc::new(Field::new_list_field(
854                    ArrowDataType::Int64,
855                    true,
856                ))),
857                true,
858            ),
859            Field::new("int_col", ArrowDataType::Int32, false),
860        ]);
861
862        let dfschema = DFSchema::try_from(schema)?;
863
864        let input = LogicalPlan::EmptyRelation(EmptyRelation {
865            produce_one_row: false,
866            schema: Arc::new(dfschema),
867        });
868
869        let mut unnest_placeholder_columns = IndexMap::new();
870        let mut inner_projection_exprs = vec![];
871
872        // unnest(struct_col)
873        let original_expr = unnest(col("struct_col"));
874        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
875            &input,
876            &mut unnest_placeholder_columns,
877            &mut inner_projection_exprs,
878            &original_expr,
879        )?;
880        assert_eq!(
881            transformed_exprs,
882            vec![
883                col("__unnest_placeholder(struct_col).field1"),
884                col("__unnest_placeholder(struct_col).field2"),
885            ]
886        );
887        column_unnests_eq(
888            vec!["__unnest_placeholder(struct_col)"],
889            &unnest_placeholder_columns,
890        );
891        // Still reference struct_col in original schema but with alias,
892        // to avoid colliding with the projection on the column itself if any
893        assert_eq!(
894            inner_projection_exprs,
895            vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
896        );
897
898        // unnest(array_col) + 1
899        let original_expr = unnest(col("array_col")).add(lit(1i64));
900        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
901            &input,
902            &mut unnest_placeholder_columns,
903            &mut inner_projection_exprs,
904            &original_expr,
905        )?;
906        column_unnests_eq(
907            vec![
908                "__unnest_placeholder(struct_col)",
909                "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
910            ],
911            &unnest_placeholder_columns,
912        );
913        // Only transform the unnest children
914        assert_eq!(
915            transformed_exprs,
916            vec![
917                col("__unnest_placeholder(array_col,depth=1)")
918                    .alias("UNNEST(array_col)")
919                    .add(lit(1i64))
920            ]
921        );
922
923        // Keep appending to the current vector
924        // Still reference array_col in original schema but with alias,
925        // to avoid colliding with the projection on the column itself if any
926        assert_eq!(
927            inner_projection_exprs,
928            vec![
929                col("struct_col").alias("__unnest_placeholder(struct_col)"),
930                col("array_col").alias("__unnest_placeholder(array_col)")
931            ]
932        );
933
934        Ok(())
935    }
936
937    // Unnest -> field access -> unnest
938    #[test]
939    fn test_transform_non_consecutive_unnests() -> Result<()> {
940        // List of struct
941        // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}]
942        let schema = Schema::new(vec![
943            Field::new(
944                "struct_list",
945                ArrowDataType::List(Arc::new(Field::new(
946                    "element",
947                    ArrowDataType::Struct(Fields::from(vec![
948                        Field::new(
949                            // list of i64
950                            "subfield1",
951                            ArrowDataType::List(Arc::new(Field::new(
952                                "i64_element",
953                                ArrowDataType::Int64,
954                                true,
955                            ))),
956                            true,
957                        ),
958                        Field::new(
959                            // list of utf8
960                            "subfield2",
961                            ArrowDataType::List(Arc::new(Field::new(
962                                "utf8_element",
963                                ArrowDataType::Utf8,
964                                true,
965                            ))),
966                            true,
967                        ),
968                    ])),
969                    true,
970                ))),
971                true,
972            ),
973            Field::new("int_col", ArrowDataType::Int32, false),
974        ]);
975
976        let dfschema = DFSchema::try_from(schema)?;
977
978        let input = LogicalPlan::EmptyRelation(EmptyRelation {
979            produce_one_row: false,
980            schema: Arc::new(dfschema),
981        });
982
983        let mut unnest_placeholder_columns = IndexMap::new();
984        let mut inner_projection_exprs = vec![];
985
986        // An expr with multiple unnest
987        let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
988        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
989            &input,
990            &mut unnest_placeholder_columns,
991            &mut inner_projection_exprs,
992            &select_expr1,
993        )?;
994        // Only the inner most/ bottom most unnest is transformed
995        assert_eq!(
996            transformed_exprs,
997            vec![unnest(
998                col("__unnest_placeholder(struct_list,depth=1)")
999                    .alias("UNNEST(struct_list)")
1000                    .field("subfield1")
1001            )]
1002        );
1003
1004        column_unnests_eq(
1005            vec![
1006                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
1007            ],
1008            &unnest_placeholder_columns,
1009        );
1010
1011        assert_eq!(
1012            inner_projection_exprs,
1013            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
1014        );
1015
1016        // continue rewrite another expr in select
1017        let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
1018        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
1019            &input,
1020            &mut unnest_placeholder_columns,
1021            &mut inner_projection_exprs,
1022            &select_expr2,
1023        )?;
1024        // Only the inner most/ bottom most unnest is transformed
1025        assert_eq!(
1026            transformed_exprs,
1027            vec![unnest(
1028                col("__unnest_placeholder(struct_list,depth=1)")
1029                    .alias("UNNEST(struct_list)")
1030                    .field("subfield2")
1031            )]
1032        );
1033
1034        // unnest place holder columns remain the same
1035        // because expr1 and expr2 derive from the same unnest result
1036        column_unnests_eq(
1037            vec![
1038                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
1039            ],
1040            &unnest_placeholder_columns,
1041        );
1042
1043        assert_eq!(
1044            inner_projection_exprs,
1045            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
1046        );
1047
1048        Ok(())
1049    }
1050
1051    #[test]
1052    fn test_resolve_positions_to_exprs() -> Result<()> {
1053        let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];
1054
1055        // Assert 1 resolved as first column in select list
1056        let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
1057        assert_eq!(resolved, col("c1"));
1058
1059        // Assert error if index out of select clause bounds
1060        let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
1061        assert!(resolved.is_err_and(|e| e.message().contains(
1062            "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
1063        )));
1064
1065        let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
1066        assert!(resolved.is_err_and(|e| e.message().contains(
1067            "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
1068        )));
1069
1070        // Assert expression returned as-is
1071        let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
1072        assert_eq!(resolved, lit("text"));
1073
1074        let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
1075        assert_eq!(resolved, col("fake"));
1076
1077        Ok(())
1078    }
1079}