datafusion_sql/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! SQL Utility Functions
19
20use std::vec;
21
22use arrow::datatypes::{
23    DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
24};
25use datafusion_common::tree_node::{
26    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
27};
28use datafusion_common::{
29    exec_err, internal_err, plan_err, Column, DFSchemaRef, DataFusionError, Diagnostic,
30    HashMap, Result, ScalarValue,
31};
32use datafusion_expr::builder::get_struct_unnested_columns;
33use datafusion_expr::expr::{
34    Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
35};
36use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
37use datafusion_expr::{
38    col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan,
39};
40
41use indexmap::IndexMap;
42use sqlparser::ast::{Ident, Value};
43
44/// Make a best-effort attempt at resolving all columns in the expression tree
45pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
46    expr.clone()
47        .transform_up(|nested_expr| {
48            match nested_expr {
49                Expr::Column(col) => {
50                    let (qualifier, field) =
51                        plan.schema().qualified_field_from_column(&col)?;
52                    Ok(Transformed::yes(Expr::Column(Column::from((
53                        qualifier, field,
54                    )))))
55                }
56                _ => {
57                    // keep recursing
58                    Ok(Transformed::no(nested_expr))
59                }
60            }
61        })
62        .data()
63}
64
65/// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
66///
67/// For example, the expression `a + b < 1` would require, as input, the 2
68/// individual columns, `a` and `b`. But, if the base expressions already
69/// contain the `a + b` result, then that may be used in lieu of the `a` and
70/// `b` columns.
71///
72/// This is useful in the context of a query like:
73///
74/// SELECT a + b < 1 ... GROUP BY a + b
75///
76/// where post-aggregation, `a + b` need not be a projection against the
77/// individual columns `a` and `b`, but rather it is a projection against the
78/// `a + b` found in the GROUP BY.
79pub(crate) fn rebase_expr(
80    expr: &Expr,
81    base_exprs: &[Expr],
82    plan: &LogicalPlan,
83) -> Result<Expr> {
84    expr.clone()
85        .transform_down(|nested_expr| {
86            if base_exprs.contains(&nested_expr) {
87                Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
88            } else {
89                Ok(Transformed::no(nested_expr))
90            }
91        })
92        .data()
93}
94
/// Identifies which clause is being validated by
/// `check_columns_satisfy_exprs`, so that the resulting error message can
/// name the offending clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
    /// The checked expressions come from the SELECT list; errors are prefixed
    /// with "Column in SELECT must be in GROUP BY or an aggregate function".
    ProjectionMustReferenceAggregate,
    /// The checked expressions come from the HAVING clause; errors are
    /// prefixed with "Column in HAVING must be in GROUP BY or an aggregate function".
    HavingMustReferenceAggregate,
}
100
101impl CheckColumnsSatisfyExprsPurpose {
102    fn message_prefix(&self) -> &'static str {
103        match self {
104            CheckColumnsSatisfyExprsPurpose::ProjectionMustReferenceAggregate => {
105                "Column in SELECT must be in GROUP BY or an aggregate function"
106            }
107            CheckColumnsSatisfyExprsPurpose::HavingMustReferenceAggregate => {
108                "Column in HAVING must be in GROUP BY or an aggregate function"
109            }
110        }
111    }
112
113    fn diagnostic_message(&self, expr: &Expr) -> String {
114        format!("'{expr}' must appear in GROUP BY clause because it's not an aggregate expression")
115    }
116}
117
118/// Determines if the set of `Expr`'s are a valid projection on the input
119/// `Expr::Column`'s.
120pub(crate) fn check_columns_satisfy_exprs(
121    columns: &[Expr],
122    exprs: &[Expr],
123    purpose: CheckColumnsSatisfyExprsPurpose,
124) -> Result<()> {
125    columns.iter().try_for_each(|c| match c {
126        Expr::Column(_) => Ok(()),
127        _ => internal_err!("Expr::Column are required"),
128    })?;
129    let column_exprs = find_column_exprs(exprs);
130    for e in &column_exprs {
131        match e {
132            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
133                for e in exprs {
134                    check_column_satisfies_expr(columns, e, purpose)?;
135                }
136            }
137            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
138                for e in exprs {
139                    check_column_satisfies_expr(columns, e, purpose)?;
140                }
141            }
142            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
143                for exprs in lists_of_exprs {
144                    for e in exprs {
145                        check_column_satisfies_expr(columns, e, purpose)?;
146                    }
147                }
148            }
149            _ => check_column_satisfies_expr(columns, e, purpose)?,
150        }
151    }
152    Ok(())
153}
154
155fn check_column_satisfies_expr(
156    columns: &[Expr],
157    expr: &Expr,
158    purpose: CheckColumnsSatisfyExprsPurpose,
159) -> Result<()> {
160    if !columns.contains(expr) {
161        return plan_err!(
162            "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
163            purpose.message_prefix(),
164            expr,
165            expr_vec_fmt!(columns)
166        )
167        .map_err(|err| {
168            let diagnostic = Diagnostic::new_error(
169                purpose.diagnostic_message(expr),
170                expr.spans().and_then(|spans| spans.first()),
171            )
172            .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregare function like ANY_VALUE({expr})"), None);
173            err.with_diagnostic(diagnostic)
174        });
175    }
176    Ok(())
177}
178
179/// Returns mapping of each alias (`String`) to the expression (`Expr`) it is
180/// aliasing.
181pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
182    exprs
183        .iter()
184        .filter_map(|expr| match expr {
185            Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
186            _ => None,
187        })
188        .collect::<HashMap<String, Expr>>()
189}
190
191/// Given an expression that's literal int encoding position, lookup the corresponding expression
192/// in the select_exprs list, if the index is within the bounds and it is indeed a position literal,
193/// otherwise, returns planning error.
194/// If input expression is not an int literal, returns expression as-is.
195pub(crate) fn resolve_positions_to_exprs(
196    expr: Expr,
197    select_exprs: &[Expr],
198) -> Result<Expr> {
199    match expr {
200        // sql_expr_to_logical_expr maps number to i64
201        // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887
202        Expr::Literal(ScalarValue::Int64(Some(position)))
203            if position > 0_i64 && position <= select_exprs.len() as i64 =>
204        {
205            let index = (position - 1) as usize;
206            let select_expr = &select_exprs[index];
207            Ok(match select_expr {
208                Expr::Alias(Alias { expr, .. }) => *expr.clone(),
209                _ => select_expr.clone(),
210            })
211        }
212        Expr::Literal(ScalarValue::Int64(Some(position))) => plan_err!(
213            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
214            position, select_exprs.len()
215        ),
216        _ => Ok(expr),
217    }
218}
219
220/// Rebuilds an `Expr` with columns that refer to aliases replaced by the
221/// alias' underlying `Expr`.
222pub(crate) fn resolve_aliases_to_exprs(
223    expr: Expr,
224    aliases: &HashMap<String, Expr>,
225) -> Result<Expr> {
226    expr.transform_up(|nested_expr| match nested_expr {
227        Expr::Column(c) if c.relation.is_none() => {
228            if let Some(aliased_expr) = aliases.get(&c.name) {
229                Ok(Transformed::yes(aliased_expr.clone()))
230            } else {
231                Ok(Transformed::no(Expr::Column(c)))
232            }
233        }
234        _ => Ok(Transformed::no(nested_expr)),
235    })
236    .data()
237}
238
239/// Given a slice of window expressions sharing the same sort key, find their common partition
240/// keys.
241pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
242    let all_partition_keys = window_exprs
243        .iter()
244        .map(|expr| match expr {
245            Expr::WindowFunction(WindowFunction {
246                params: WindowFunctionParams { partition_by, .. },
247                ..
248            }) => Ok(partition_by),
249            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
250                Expr::WindowFunction(WindowFunction {
251                    params: WindowFunctionParams { partition_by, .. },
252                    ..
253                }) => Ok(partition_by),
254                expr => exec_err!("Impossibly got non-window expr {expr:?}"),
255            },
256            expr => exec_err!("Impossibly got non-window expr {expr:?}"),
257        })
258        .collect::<Result<Vec<_>>>()?;
259    let result = all_partition_keys
260        .iter()
261        .min_by_key(|s| s.len())
262        .ok_or_else(|| {
263            DataFusionError::Execution("No window expressions found".to_owned())
264        })?;
265    Ok(result)
266}
267
268/// Returns a validated `DataType` for the specified precision and
269/// scale
270pub(crate) fn make_decimal_type(
271    precision: Option<u64>,
272    scale: Option<u64>,
273) -> Result<DataType> {
274    // postgres like behavior
275    let (precision, scale) = match (precision, scale) {
276        (Some(p), Some(s)) => (p as u8, s as i8),
277        (Some(p), None) => (p as u8, 0),
278        (None, Some(_)) => {
279            return plan_err!("Cannot specify only scale for decimal data type")
280        }
281        (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
282    };
283
284    if precision == 0
285        || precision > DECIMAL256_MAX_PRECISION
286        || scale.unsigned_abs() > precision
287    {
288        plan_err!(
289            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
290        )
291    } else if precision > DECIMAL128_MAX_PRECISION
292        && precision <= DECIMAL256_MAX_PRECISION
293    {
294        Ok(DataType::Decimal256(precision, scale))
295    } else {
296        Ok(DataType::Decimal128(precision, scale))
297    }
298}
299
300/// Normalize an owned identifier to a lowercase string, unless the identifier is quoted.
301pub(crate) fn normalize_ident(id: Ident) -> String {
302    match id.quote_style {
303        Some(_) => id.value,
304        None => id.value.to_ascii_lowercase(),
305    }
306}
307
308pub(crate) fn value_to_string(value: &Value) -> Option<String> {
309    match value {
310        Value::SingleQuotedString(s) => Some(s.to_string()),
311        Value::DollarQuotedString(s) => Some(s.to_string()),
312        Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
313        Value::UnicodeStringLiteral(s) => Some(s.to_string()),
314        Value::EscapedStringLiteral(s) => Some(s.to_string()),
315        Value::DoubleQuotedString(_)
316        | Value::NationalStringLiteral(_)
317        | Value::SingleQuotedByteStringLiteral(_)
318        | Value::DoubleQuotedByteStringLiteral(_)
319        | Value::TripleSingleQuotedString(_)
320        | Value::TripleDoubleQuotedString(_)
321        | Value::TripleSingleQuotedByteStringLiteral(_)
322        | Value::TripleDoubleQuotedByteStringLiteral(_)
323        | Value::SingleQuotedRawStringLiteral(_)
324        | Value::DoubleQuotedRawStringLiteral(_)
325        | Value::TripleSingleQuotedRawStringLiteral(_)
326        | Value::TripleDoubleQuotedRawStringLiteral(_)
327        | Value::HexStringLiteral(_)
328        | Value::Null
329        | Value::Placeholder(_) => None,
330    }
331}
332
333pub(crate) fn rewrite_recursive_unnests_bottom_up(
334    input: &LogicalPlan,
335    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
336    inner_projection_exprs: &mut Vec<Expr>,
337    original_exprs: &[Expr],
338) -> Result<Vec<Expr>> {
339    Ok(original_exprs
340        .iter()
341        .map(|expr| {
342            rewrite_recursive_unnest_bottom_up(
343                input,
344                unnest_placeholder_columns,
345                inner_projection_exprs,
346                expr,
347            )
348        })
349        .collect::<Result<Vec<_>>>()?
350        .into_iter()
351        .flatten()
352        .collect::<Vec<_>>())
353}
354
/// Prefix of the column names generated while rewriting unnest expressions:
/// `__unnest_placeholder(<expr>)` names the aliased input expression in the
/// inner projection, and `__unnest_placeholder(<expr>,depth=<n>)` names the
/// output column of a list unnest at depth `n`.
pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";
356
/// Rewriter that replaces unnest expressions by references to placeholder
/// columns, recording along the way what the inner projection and the unnest
/// operation have to produce.
///
/// This is only useful when driven by a combined top-down/bottom-up traversal
/// (`TreeNode::rewrite`), because `f_down` records the unnest nodes that are
/// later consumed by `f_up`.
struct RecursiveUnnestRewriter<'a> {
    // Schema used to compute the data type of the expressions being unnested
    input_schema: &'a DFSchemaRef,
    // The expression the rewrite started from; struct unnest is only allowed
    // when the unnest node being rewritten is this expression
    root_expr: &'a Expr,
    // Useful to detect which child expr is a part of/ not a part of unnest operation:
    // set by `f_down` on the first unnest met, reset by `f_up` when the
    // traversal comes back to that same unnest
    top_most_unnest: Option<Unnest>,
    // History of the traversal: `Some(unnest)` for each unnest node visited on
    // the way down, `None` for any other visit (see `f_down` / `f_up`)
    consecutive_unnest: Vec<Option<Unnest>>,
    // Expressions that the inner projection (below the unnest) must evaluate
    inner_projection_exprs: &'a mut Vec<Expr>,
    // For each placeholder column: the list unnestings to perform on it
    // (`Some`), or `None` when the placeholder holds a struct
    columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    // Set when a struct-allowed unnest (root expression, single unnest level)
    // is rewritten: holds every expression produced by that rewrite
    transformed_root_exprs: Option<Vec<Expr>>,
}
371impl RecursiveUnnestRewriter<'_> {
372    /// This struct stores the history of expr
373    /// during its tree-traversal with a notation of
374    /// \[None,**Unnest(exprA)**,**Unnest(exprB)**,None,None\]
375    /// then this function will returns \[**Unnest(exprA)**,**Unnest(exprB)**\]
376    ///
377    /// The first item will be the inner most expr
378    fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
379        self.consecutive_unnest
380            .iter()
381            .rev()
382            .skip_while(|item| item.is_none())
383            .take_while(|item| item.is_some())
384            .to_owned()
385            .cloned()
386            .map(|item| item.unwrap())
387            .collect()
388    }
389
390    fn transform(
391        &mut self,
392        level: usize,
393        alias_name: String,
394        expr_in_unnest: &Expr,
395        struct_allowed: bool,
396    ) -> Result<Vec<Expr>> {
397        let inner_expr_name = expr_in_unnest.schema_name().to_string();
398
399        // Full context, we are trying to plan the execution as InnerProjection->Unnest->OuterProjection
400        // inside unnest execution, each column inside the inner projection
401        // will be transformed into new columns. Thus we need to keep track of these placeholding column names
402        let placeholder_name = format!("{UNNEST_PLACEHOLDER}({})", inner_expr_name);
403        let post_unnest_name =
404            format!("{UNNEST_PLACEHOLDER}({},depth={})", inner_expr_name, level);
405        // This is due to the fact that unnest transformation should keep the original
406        // column name as is, to comply with group by and order by
407        let placeholder_column = Column::from_name(placeholder_name.clone());
408
409        let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?;
410
411        match data_type {
412            DataType::Struct(inner_fields) => {
413                if !struct_allowed {
414                    return internal_err!("unnest on struct can only be applied at the root level of select expression");
415                }
416                push_projection_dedupl(
417                    self.inner_projection_exprs,
418                    expr_in_unnest.clone().alias(placeholder_name.clone()),
419                );
420                self.columns_unnestings
421                    .insert(Column::from_name(placeholder_name.clone()), None);
422                Ok(
423                    get_struct_unnested_columns(&placeholder_name, &inner_fields)
424                        .into_iter()
425                        .map(Expr::Column)
426                        .collect(),
427                )
428            }
429            DataType::List(_)
430            | DataType::FixedSizeList(_, _)
431            | DataType::LargeList(_) => {
432                push_projection_dedupl(
433                    self.inner_projection_exprs,
434                    expr_in_unnest.clone().alias(placeholder_name.clone()),
435                );
436
437                let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
438                let list_unnesting = self
439                    .columns_unnestings
440                    .entry(placeholder_column)
441                    .or_insert(Some(vec![]));
442                let unnesting = ColumnUnnestList {
443                    output_column: Column::from_name(post_unnest_name),
444                    depth: level,
445                };
446                let list_unnestings = list_unnesting.as_mut().unwrap();
447                if !list_unnestings.contains(&unnesting) {
448                    list_unnestings.push(unnesting);
449                }
450                Ok(vec![post_unnest_expr])
451            }
452            _ => {
453                internal_err!("unnest on non-list or struct type is not supported")
454            }
455        }
456    }
457}
458
impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
    type Node = Expr;

    /// This downward traversal needs to keep track of:
    /// - Whether or not some unnest expr has been visited from the top until the current node
    /// - If some unnest expr has been visited, maintain a stack of such information, this
    ///   is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))**
    ///
    /// The expression itself is never modified on the way down.
    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref unnest_expr) = expr {
            let (data_type, _) =
                unnest_expr.expr.data_type_and_nullable(self.input_schema)?;
            self.consecutive_unnest.push(Some(unnest_expr.clone()));
            // if expr inside unnest is a struct, do not consider
            // the next unnest as consecutive unnest (if any)
            // meaning unnest(unnest(struct_arr_col)) can't
            // be interpreted as unnest(struct_arr_col, depth:=2)
            // but has to be split into multiple unnest logical plan instead
            // a.k.a:
            // - unnest(struct_col)
            //      unnest(struct_arr_col) as struct_col

            if let DataType::Struct(_) = data_type {
                // A `None` entry breaks the run of consecutive unnests
                self.consecutive_unnest.push(None);
            }
            // Remember only the first (outer most) unnest met on this path
            if self.top_most_unnest.is_none() {
                self.top_most_unnest = Some(unnest_expr.clone());
            }

            Ok(Transformed::no(expr))
        } else {
            self.consecutive_unnest.push(None);
            Ok(Transformed::no(expr))
        }
    }

    /// The rewriting only happens when the traversal has reached the top-most unnest expr
    /// within a sequence of consecutive unnest exprs node
    ///
    /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))**
    /// ```text
    ///                         ┌──────────────────┐
    ///                         │    binaryexpr    │
    ///                         │                  │
    ///                         └──────────────────┘
    ///                f_down  / /            │ │
    ///                       / / f_up        │ │
    ///                      / /        f_down│ │f_up
    ///                  unnest               │ │
    ///                                       │ │
    ///       f_down  / / f_up(rewriting)     │ │
    ///              / /
    ///             / /                      unnest
    ///         unnest
    ///                           f_down  / / f_up(rewriting)
    /// f_down / /f_up                   / /
    ///       / /                       / /
    ///      / /                    unnest
    ///   column1
    ///                     f_down / /f_up
    ///                           / /
    ///                          / /
    ///                       column2
    /// ```
    ///
    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref traversing_unnest) = expr {
            // NOTE(review): this `unwrap` assumes `f_down` has recorded a
            // top-most unnest that is still set when an unnest is visited on
            // the way up -- confirm this holds for every expression shape
            if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
                self.top_most_unnest = None;
            }
            // Find inside consecutive_unnest, the sequence of continuous unnest exprs

            // Get the latest consecutive unnest exprs
            // and check if the current upward traversal is returning to the root expr
            // for example given an expr `unnest(unnest(col))` then the traversal happens like:
            // down(unnest) -> down(unnest) -> down(col) -> up(col) -> up(unnest) -> up(unnest)
            // the result of such traversal is unnest(col, depth:=2)
            let unnest_stack = self.get_latest_consecutive_unnest();

            // This traversal has reached the top most unnest again
            // e.g Unnest(top) -> Unnest(2nd) -> Column(bottom)
            // -> Unnest(2nd) -> Unnest(top) a.k.a here
            // Thus
            // Unnest(Unnest(some_col)) is rewritten into Unnest(some_col, depth:=2)
            //
            // NOTE(review): the `unwrap`s below assume `unnest_stack` is not
            // empty whenever an unnest node is visited on the way up -- confirm
            if traversing_unnest == unnest_stack.last().unwrap() {
                let most_inner = unnest_stack.first().unwrap();
                let inner_expr = most_inner.expr.as_ref();
                // unnest(unnest(struct_arr_col)) is not allowed to be done recursively
                // it needs to be split into multiple unnest logical plan
                // unnest(struct_arr)
                //  unnest(struct_arr_col) as struct_arr
                // instead of unnest(struct_arr_col, depth = 2)

                let unnest_recursion = unnest_stack.len();
                let struct_allowed = (&expr == self.root_expr) && unnest_recursion == 1;

                let mut transformed_exprs = self.transform(
                    unnest_recursion,
                    expr.schema_name().to_string(),
                    inner_expr,
                    struct_allowed,
                )?;
                if struct_allowed {
                    // Keep every produced expr: the tree rewrite itself can
                    // only return a single one (the first, see below)
                    self.transformed_root_exprs = Some(transformed_exprs.clone());
                }
                return Ok(Transformed::new(
                    transformed_exprs.swap_remove(0),
                    true,
                    TreeNodeRecursion::Continue,
                ));
            }
        } else {
            self.consecutive_unnest.push(None);
        }

        // For column exprs that are not descendants of any unnest node
        // retain their projection
        // e.g given expr tree unnest(col_a) + col_b, we have to retain projection of col_b
        // this condition can be checked by maintaining an Option<top most unnest>
        if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
            push_projection_dedupl(self.inner_projection_exprs, expr.clone());
        }

        Ok(Transformed::no(expr))
    }
}
584
585fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
586    let schema_name = expr.schema_name().to_string();
587    if !projection
588        .iter()
589        .any(|e| e.schema_name().to_string() == schema_name)
590    {
591        projection.push(expr);
592    }
593}
594/// The context is we want to rewrite unnest() into InnerProjection->Unnest->OuterProjection
595/// Given an expression which contains unnest expr as one of its children,
596/// Try transform depends on unnest type
597/// - For list column: unnest(col) with type list -> unnest(col) with type list::item
598/// - For struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2
599///
600/// The transformed exprs will be used in the outer projection
601/// If along the path from root to bottom, there are multiple unnest expressions, the transformation
602/// is done only for the bottom expression
603pub(crate) fn rewrite_recursive_unnest_bottom_up(
604    input: &LogicalPlan,
605    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
606    inner_projection_exprs: &mut Vec<Expr>,
607    original_expr: &Expr,
608) -> Result<Vec<Expr>> {
609    let mut rewriter = RecursiveUnnestRewriter {
610        input_schema: input.schema(),
611        root_expr: original_expr,
612        top_most_unnest: None,
613        consecutive_unnest: vec![],
614        inner_projection_exprs,
615        columns_unnestings: unnest_placeholder_columns,
616        transformed_root_exprs: None,
617    };
618
619    // This transformation is only done for list unnest
620    // struct unnest is done at the root level, and at the later stage
621    // because the syntax of TreeNode only support transform into 1 Expr, while
622    // Unnest struct will be transformed into multiple Exprs
623    // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102
624    //
625    // The transformation looks like:
626    // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)")
627    // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1")
628    let Transformed {
629        data: transformed_expr,
630        transformed,
631        tnr: _,
632    } = original_expr.clone().rewrite(&mut rewriter)?;
633
634    if !transformed {
635        // TODO: remove the next line after `Expr::Wildcard` is removed
636        #[expect(deprecated)]
637        if matches!(&transformed_expr, Expr::Column(_))
638            || matches!(&transformed_expr, Expr::Wildcard { .. })
639        {
640            push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
641            Ok(vec![transformed_expr])
642        } else {
643            // We need to evaluate the expr in the inner projection,
644            // outer projection just select its name
645            let column_name = transformed_expr.schema_name().to_string();
646            push_projection_dedupl(inner_projection_exprs, transformed_expr);
647            Ok(vec![Expr::Column(Column::from_name(column_name))])
648        }
649    } else {
650        if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
651            return Ok(transformed_root_exprs);
652        }
653        Ok(vec![transformed_expr])
654    }
655}
656
657#[cfg(test)]
658mod tests {
659    use std::{ops::Add, sync::Arc};
660
661    use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
662    use datafusion_common::{Column, DFSchema, Result};
663    use datafusion_expr::{
664        col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan,
665    };
666    use datafusion_functions::core::expr_ext::FieldAccessor;
667    use datafusion_functions_aggregate::expr_fn::count;
668
669    use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
670    use indexmap::IndexMap;
671
672    fn column_unnests_eq(
673        l: Vec<&str>,
674        r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
675    ) {
676        let r_formatted: Vec<String> = r
677            .iter()
678            .map(|i| match i.1 {
679                None => format!("{}", i.0),
680                Some(vec) => format!(
681                    "{}=>[{}]",
682                    i.0,
683                    vec.iter()
684                        .map(|i| format!("{}", i))
685                        .collect::<Vec<String>>()
686                        .join(", ")
687                ),
688            })
689            .collect();
690        let l_formatted: Vec<String> = l.iter().map(|i| i.to_string()).collect();
691        assert_eq!(l_formatted, r_formatted);
692    }
693
694    #[test]
695    fn test_transform_bottom_unnest_recursive() -> Result<()> {
696        let schema = Schema::new(vec![
697            Field::new(
698                "3d_col",
699                ArrowDataType::List(Arc::new(Field::new(
700                    "2d_col",
701                    ArrowDataType::List(Arc::new(Field::new(
702                        "elements",
703                        ArrowDataType::Int64,
704                        true,
705                    ))),
706                    true,
707                ))),
708                true,
709            ),
710            Field::new("i64_col", ArrowDataType::Int64, true),
711        ]);
712
713        let dfschema = DFSchema::try_from(schema)?;
714
715        let input = LogicalPlan::EmptyRelation(EmptyRelation {
716            produce_one_row: false,
717            schema: Arc::new(dfschema),
718        });
719
720        let mut unnest_placeholder_columns = IndexMap::new();
721        let mut inner_projection_exprs = vec![];
722
723        // unnest(unnest(3d_col)) + unnest(unnest(3d_col))
724        let original_expr = unnest(unnest(col("3d_col")))
725            .add(unnest(unnest(col("3d_col"))))
726            .add(col("i64_col"));
727        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
728            &input,
729            &mut unnest_placeholder_columns,
730            &mut inner_projection_exprs,
731            &original_expr,
732        )?;
733        // Only the bottom most unnest exprs are transformed
734        assert_eq!(
735            transformed_exprs,
736            vec![col("__unnest_placeholder(3d_col,depth=2)")
737                .alias("UNNEST(UNNEST(3d_col))")
738                .add(
739                    col("__unnest_placeholder(3d_col,depth=2)")
740                        .alias("UNNEST(UNNEST(3d_col))")
741                )
742                .add(col("i64_col"))]
743        );
744        column_unnests_eq(
745            vec![
746                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
747            ],
748            &unnest_placeholder_columns,
749        );
750
751        // Still reference struct_col in original schema but with alias,
752        // to avoid colliding with the projection on the column itself if any
753        assert_eq!(
754            inner_projection_exprs,
755            vec![
756                col("3d_col").alias("__unnest_placeholder(3d_col)"),
757                col("i64_col")
758            ]
759        );
760
761        // unnest(3d_col) as 2d_col
762        let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
763        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
764            &input,
765            &mut unnest_placeholder_columns,
766            &mut inner_projection_exprs,
767            &original_expr_2,
768        )?;
769
770        assert_eq!(
771            transformed_exprs,
772            vec![
773                (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
774                    .alias("2d_col")
775            ]
776        );
777        column_unnests_eq(
778            vec!["__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]"],
779            &unnest_placeholder_columns,
780        );
        // The inner projection is unchanged by the second rewrite: 3d_col is
        // still referenced through the same alias, to avoid colliding with the
        // projection on the column itself if any
783        assert_eq!(
784            inner_projection_exprs,
785            vec![
786                col("3d_col").alias("__unnest_placeholder(3d_col)"),
787                col("i64_col")
788            ]
789        );
790
791        Ok(())
792    }
793
794    #[test]
795    fn test_transform_bottom_unnest() -> Result<()> {
796        let schema = Schema::new(vec![
797            Field::new(
798                "struct_col",
799                ArrowDataType::Struct(Fields::from(vec![
800                    Field::new("field1", ArrowDataType::Int32, false),
801                    Field::new("field2", ArrowDataType::Int32, false),
802                ])),
803                false,
804            ),
805            Field::new(
806                "array_col",
807                ArrowDataType::List(Arc::new(Field::new_list_field(
808                    ArrowDataType::Int64,
809                    true,
810                ))),
811                true,
812            ),
813            Field::new("int_col", ArrowDataType::Int32, false),
814        ]);
815
816        let dfschema = DFSchema::try_from(schema)?;
817
818        let input = LogicalPlan::EmptyRelation(EmptyRelation {
819            produce_one_row: false,
820            schema: Arc::new(dfschema),
821        });
822
823        let mut unnest_placeholder_columns = IndexMap::new();
824        let mut inner_projection_exprs = vec![];
825
826        // unnest(struct_col)
827        let original_expr = unnest(col("struct_col"));
828        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
829            &input,
830            &mut unnest_placeholder_columns,
831            &mut inner_projection_exprs,
832            &original_expr,
833        )?;
834        assert_eq!(
835            transformed_exprs,
836            vec![
837                col("__unnest_placeholder(struct_col).field1"),
838                col("__unnest_placeholder(struct_col).field2"),
839            ]
840        );
841        column_unnests_eq(
842            vec!["__unnest_placeholder(struct_col)"],
843            &unnest_placeholder_columns,
844        );
845        // Still reference struct_col in original schema but with alias,
846        // to avoid colliding with the projection on the column itself if any
847        assert_eq!(
848            inner_projection_exprs,
849            vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
850        );
851
852        // unnest(array_col) + 1
853        let original_expr = unnest(col("array_col")).add(lit(1i64));
854        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
855            &input,
856            &mut unnest_placeholder_columns,
857            &mut inner_projection_exprs,
858            &original_expr,
859        )?;
860        column_unnests_eq(
861            vec![
862                "__unnest_placeholder(struct_col)",
863                "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
864            ],
865            &unnest_placeholder_columns,
866        );
867        // Only transform the unnest children
868        assert_eq!(
869            transformed_exprs,
870            vec![col("__unnest_placeholder(array_col,depth=1)")
871                .alias("UNNEST(array_col)")
872                .add(lit(1i64))]
873        );
874
875        // Keep appending to the current vector
876        // Still reference array_col in original schema but with alias,
877        // to avoid colliding with the projection on the column itself if any
878        assert_eq!(
879            inner_projection_exprs,
880            vec![
881                col("struct_col").alias("__unnest_placeholder(struct_col)"),
882                col("array_col").alias("__unnest_placeholder(array_col)")
883            ]
884        );
885
886        Ok(())
887    }
888
889    // Unnest -> field access -> unnest
890    #[test]
891    fn test_transform_non_consecutive_unnests() -> Result<()> {
892        // List of struct
893        // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}]
894        let schema = Schema::new(vec![
895            Field::new(
896                "struct_list",
897                ArrowDataType::List(Arc::new(Field::new(
898                    "element",
899                    ArrowDataType::Struct(Fields::from(vec![
900                        Field::new(
901                            // list of i64
902                            "subfield1",
903                            ArrowDataType::List(Arc::new(Field::new(
904                                "i64_element",
905                                ArrowDataType::Int64,
906                                true,
907                            ))),
908                            true,
909                        ),
910                        Field::new(
911                            // list of utf8
912                            "subfield2",
913                            ArrowDataType::List(Arc::new(Field::new(
914                                "utf8_element",
915                                ArrowDataType::Utf8,
916                                true,
917                            ))),
918                            true,
919                        ),
920                    ])),
921                    true,
922                ))),
923                true,
924            ),
925            Field::new("int_col", ArrowDataType::Int32, false),
926        ]);
927
928        let dfschema = DFSchema::try_from(schema)?;
929
930        let input = LogicalPlan::EmptyRelation(EmptyRelation {
931            produce_one_row: false,
932            schema: Arc::new(dfschema),
933        });
934
935        let mut unnest_placeholder_columns = IndexMap::new();
936        let mut inner_projection_exprs = vec![];
937
938        // An expr with multiple unnest
939        let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
940        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
941            &input,
942            &mut unnest_placeholder_columns,
943            &mut inner_projection_exprs,
944            &select_expr1,
945        )?;
946        // Only the inner most/ bottom most unnest is transformed
947        assert_eq!(
948            transformed_exprs,
949            vec![unnest(
950                col("__unnest_placeholder(struct_list,depth=1)")
951                    .alias("UNNEST(struct_list)")
952                    .field("subfield1")
953            )]
954        );
955
956        column_unnests_eq(
957            vec![
958                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
959            ],
960            &unnest_placeholder_columns,
961        );
962
963        assert_eq!(
964            inner_projection_exprs,
965            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
966        );
967
968        // continue rewrite another expr in select
969        let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
970        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
971            &input,
972            &mut unnest_placeholder_columns,
973            &mut inner_projection_exprs,
974            &select_expr2,
975        )?;
976        // Only the inner most/ bottom most unnest is transformed
977        assert_eq!(
978            transformed_exprs,
979            vec![unnest(
980                col("__unnest_placeholder(struct_list,depth=1)")
981                    .alias("UNNEST(struct_list)")
982                    .field("subfield2")
983            )]
984        );
985
986        // unnest place holder columns remain the same
987        // because expr1 and expr2 derive from the same unnest result
988        column_unnests_eq(
989            vec![
990                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
991            ],
992            &unnest_placeholder_columns,
993        );
994
995        assert_eq!(
996            inner_projection_exprs,
997            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
998        );
999
1000        Ok(())
1001    }
1002
1003    #[test]
1004    fn test_resolve_positions_to_exprs() -> Result<()> {
1005        let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];
1006
1007        // Assert 1 resolved as first column in select list
1008        let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
1009        assert_eq!(resolved, col("c1"));
1010
1011        // Assert error if index out of select clause bounds
1012        let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
1013        assert!(resolved.is_err_and(|e| e.message().contains(
1014            "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
1015        )));
1016
1017        let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
1018        assert!(resolved.is_err_and(|e| e.message().contains(
1019            "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
1020        )));
1021
1022        // Assert expression returned as-is
1023        let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
1024        assert_eq!(resolved, lit("text"));
1025
1026        let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
1027        assert_eq!(resolved, col("fake"));
1028
1029        Ok(())
1030    }
1031}