datafusion_sql/utils.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! SQL Utility Functions

use std::vec;

use arrow::datatypes::{
    DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
};
use datafusion_common::tree_node::{
    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
};
use datafusion_common::{
    exec_err, internal_err, plan_err, Column, DFSchemaRef, DataFusionError, Diagnostic,
    HashMap, Result, ScalarValue,
};
use datafusion_expr::builder::get_struct_unnested_columns;
use datafusion_expr::expr::{
    Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
};
use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
use datafusion_expr::{
    col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan,
};

use indexmap::IndexMap;
use sqlparser::ast::{Ident, Value};

/// Make a best-effort attempt at resolving all columns in the expression tree
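///
/// Illustrative sketch (table and column names are hypothetical): given a plan
/// whose schema contains the qualified field `t1.a`, an unqualified column
/// reference is rewritten to carry that qualifier:
///
/// ```text
/// Expr::Column("a")  ->  Expr::Column("t1.a")
/// ```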
pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
    expr.clone()
        .transform_up(|nested_expr| {
            match nested_expr {
                Expr::Column(col) => {
                    let (qualifier, field) =
                        plan.schema().qualified_field_from_column(&col)?;
                    Ok(Transformed::yes(Expr::Column(Column::from((
                        qualifier, field,
                    )))))
                }
                _ => {
                    // keep recursing
                    Ok(Transformed::no(nested_expr))
                }
            }
        })
        .data()
}

/// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s.
///
/// For example, the expression `a + b < 1` would require, as input, the 2
/// individual columns, `a` and `b`. But, if the base expressions already
/// contain the `a + b` result, then that may be used in lieu of the `a` and
/// `b` columns.
///
/// This is useful in the context of a query like:
///
/// SELECT a + b < 1 ... GROUP BY a + b
///
/// where post-aggregation, `a + b` need not be a projection against the
/// individual columns `a` and `b`, but rather it is a projection against the
/// `a + b` found in the GROUP BY.
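///
/// Illustrative sketch (hypothetical exprs): with `base_exprs = [a + b]`, the
/// predicate is rewritten so that it references the already-computed result
/// rather than the raw input columns:
///
/// ```text
/// (a + b) < 1   ->   Column("a + b") < 1
/// ```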
pub(crate) fn rebase_expr(
    expr: &Expr,
    base_exprs: &[Expr],
    plan: &LogicalPlan,
) -> Result<Expr> {
    expr.clone()
        .transform_down(|nested_expr| {
            if base_exprs.contains(&nested_expr) {
                Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
            } else {
                Ok(Transformed::no(nested_expr))
            }
        })
        .data()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
    ProjectionMustReferenceAggregate,
    HavingMustReferenceAggregate,
}

impl CheckColumnsSatisfyExprsPurpose {
    fn message_prefix(&self) -> &'static str {
        match self {
            CheckColumnsSatisfyExprsPurpose::ProjectionMustReferenceAggregate => {
                "Column in SELECT must be in GROUP BY or an aggregate function"
            }
            CheckColumnsSatisfyExprsPurpose::HavingMustReferenceAggregate => {
                "Column in HAVING must be in GROUP BY or an aggregate function"
            }
        }
    }

    fn diagnostic_message(&self, expr: &Expr) -> String {
        format!("'{expr}' must appear in GROUP BY clause because it's not an aggregate expression")
    }
}

/// Determines if the set of `Expr`'s is a valid projection on the input
/// `Expr::Column`'s.
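///
/// Illustrative sketch (hypothetical query `SELECT a, b, MAX(c) ... GROUP BY a`):
///
/// ```text
/// columns = [a, MAX(c)]       // outputs available from the aggregation
/// exprs   = [a, b, MAX(c)]    // SELECT list
/// -> error: column "b" is neither a grouping key nor wrapped in an aggregate
/// ```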
pub(crate) fn check_columns_satisfy_exprs(
    columns: &[Expr],
    exprs: &[Expr],
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    columns.iter().try_for_each(|c| match c {
        Expr::Column(_) => Ok(()),
        _ => internal_err!("Expr::Column are required"),
    })?;
    let column_exprs = find_column_exprs(exprs);
    for e in &column_exprs {
        match e {
            Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
                for e in exprs {
                    check_column_satisfies_expr(columns, e, purpose)?;
                }
            }
            Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
                for exprs in lists_of_exprs {
                    for e in exprs {
                        check_column_satisfies_expr(columns, e, purpose)?;
                    }
                }
            }
            _ => check_column_satisfies_expr(columns, e, purpose)?,
        }
    }
    Ok(())
}

fn check_column_satisfies_expr(
    columns: &[Expr],
    expr: &Expr,
    purpose: CheckColumnsSatisfyExprsPurpose,
) -> Result<()> {
    if !columns.contains(expr) {
        let diagnostic = Diagnostic::new_error(
            purpose.diagnostic_message(expr),
            expr.spans().and_then(|spans| spans.first()),
        )
        .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);

        return plan_err!(
            "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
            purpose.message_prefix(),
            expr,
            expr_vec_fmt!(columns);
            diagnostic=diagnostic
        );
    }
    Ok(())
}

/// Returns mapping of each alias (`String`) to the expression (`Expr`) it is
/// aliasing.
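///
/// Illustrative sketch (hypothetical exprs): non-alias expressions are skipped.
///
/// ```text
/// [col("a").alias("x"), col("b")]   ->   {"x": col("a")}
/// ```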
pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
    exprs
        .iter()
        .filter_map(|expr| match expr {
            Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
            _ => None,
        })
        .collect::<HashMap<String, Expr>>()
}

/// Given an expression that is a literal int encoding a position, look up the
/// corresponding expression in the `select_exprs` list. Returns a planning
/// error if the position is out of bounds.
/// If the input expression is not an int literal, returns the expression as-is.
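///
/// Illustrative sketch (hypothetical SELECT list): with
/// `select_exprs = [col("a"), (col("b") + lit(1)).alias("c")]`,
///
/// ```text
/// lit(2)    ->  col("b") + lit(1)   // position 2, alias unwrapped
/// lit(7)    ->  planning error      // out of bounds
/// lit("x")  ->  lit("x")            // not an int literal, returned as-is
/// ```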
pub(crate) fn resolve_positions_to_exprs(
    expr: Expr,
    select_exprs: &[Expr],
) -> Result<Expr> {
    match expr {
        // sql_expr_to_logical_expr maps number to i64
        // https://github.com/apache/datafusion/blob/8d175c759e17190980f270b5894348dc4cff9bbf/datafusion/src/sql/planner.rs#L882-L887
        Expr::Literal(ScalarValue::Int64(Some(position)), _)
            if position > 0_i64 && position <= select_exprs.len() as i64 =>
        {
            let index = (position - 1) as usize;
            let select_expr = &select_exprs[index];
            Ok(match select_expr {
                Expr::Alias(Alias { expr, .. }) => *expr.clone(),
                _ => select_expr.clone(),
            })
        }
        Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!(
            "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
            position, select_exprs.len()
        ),
        _ => Ok(expr),
    }
}

/// Rebuilds an `Expr` with columns that refer to aliases replaced by the
/// alias' underlying `Expr`.
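///
/// Illustrative sketch (hypothetical aliases): only unqualified column
/// references are replaced. With `aliases = {"total": SUM(b)}`,
///
/// ```text
/// col("total") > lit(10)   ->   SUM(b) > lit(10)
/// ```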
pub(crate) fn resolve_aliases_to_exprs(
    expr: Expr,
    aliases: &HashMap<String, Expr>,
) -> Result<Expr> {
    expr.transform_up(|nested_expr| match nested_expr {
        Expr::Column(c) if c.relation.is_none() => {
            if let Some(aliased_expr) = aliases.get(&c.name) {
                Ok(Transformed::yes(aliased_expr.clone()))
            } else {
                Ok(Transformed::no(Expr::Column(c)))
            }
        }
        _ => Ok(Transformed::no(nested_expr)),
    })
    .data()
}

/// Given a slice of window expressions sharing the same sort key, find their common partition
/// keys.
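///
/// Illustrative sketch (hypothetical window exprs): the shortest `PARTITION BY`
/// list among the inputs is returned.
///
/// ```text
/// [SUM(x) OVER (PARTITION BY a, b ...), MAX(y) OVER (PARTITION BY a ...)]  ->  [a]
/// ```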
pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
    let all_partition_keys = window_exprs
        .iter()
        .map(|expr| match expr {
            Expr::WindowFunction(window_fun) => {
                let WindowFunction {
                    params: WindowFunctionParams { partition_by, .. },
                    ..
                } = window_fun.as_ref();
                Ok(partition_by)
            }
            Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
                Expr::WindowFunction(window_fun) => {
                    let WindowFunction {
                        params: WindowFunctionParams { partition_by, .. },
                        ..
                    } = window_fun.as_ref();
                    Ok(partition_by)
                }
                expr => exec_err!("Impossibly got non-window expr {expr:?}"),
            },
            expr => exec_err!("Impossibly got non-window expr {expr:?}"),
        })
        .collect::<Result<Vec<_>>>()?;
    let result = all_partition_keys
        .iter()
        .min_by_key(|s| s.len())
        .ok_or_else(|| {
            DataFusionError::Execution("No window expressions found".to_owned())
        })?;
    Ok(result)
}

/// Returns a validated `DataType` for the specified precision and
/// scale
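///
/// Illustrative sketch of the Postgres-like behavior implemented below,
/// assuming Arrow's defaults (max Decimal128 precision 38, default scale 10):
///
/// ```text
/// (Some(10), Some(2))  ->  Decimal128(10, 2)
/// (Some(10), None)     ->  Decimal128(10, 0)
/// (Some(50), Some(2))  ->  Decimal256(50, 2)   // precision > 38 uses Decimal256
/// (None, Some(2))      ->  plan error          // scale without precision
/// (None, None)         ->  Decimal128(38, 10)
/// ```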
pub(crate) fn make_decimal_type(
    precision: Option<u64>,
    scale: Option<u64>,
) -> Result<DataType> {
    // postgres like behavior
    let (precision, scale) = match (precision, scale) {
        (Some(p), Some(s)) => (p as u8, s as i8),
        (Some(p), None) => (p as u8, 0),
        (None, Some(_)) => {
            return plan_err!("Cannot specify only scale for decimal data type")
        }
        (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
    };

    if precision == 0
        || precision > DECIMAL256_MAX_PRECISION
        || scale.unsigned_abs() > precision
    {
        plan_err!(
            "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
        )
    } else if precision > DECIMAL128_MAX_PRECISION
        && precision <= DECIMAL256_MAX_PRECISION
    {
        Ok(DataType::Decimal256(precision, scale))
    } else {
        Ok(DataType::Decimal128(precision, scale))
    }
}

/// Normalize an owned identifier to a lowercase string, unless the identifier is quoted.
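///
/// Illustrative sketch:
///
/// ```text
/// ColumnName     ->  "columnname"   // unquoted: lowercased
/// "ColumnName"   ->  "ColumnName"   // quoted: preserved as-is
/// ```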
pub(crate) fn normalize_ident(id: Ident) -> String {
    match id.quote_style {
        Some(_) => id.value,
        None => id.value.to_ascii_lowercase(),
    }
}

pub(crate) fn value_to_string(value: &Value) -> Option<String> {
    match value {
        Value::SingleQuotedString(s) => Some(s.to_string()),
        Value::DollarQuotedString(s) => Some(s.to_string()),
        Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
        Value::UnicodeStringLiteral(s) => Some(s.to_string()),
        Value::EscapedStringLiteral(s) => Some(s.to_string()),
        Value::DoubleQuotedString(_)
        | Value::NationalStringLiteral(_)
        | Value::SingleQuotedByteStringLiteral(_)
        | Value::DoubleQuotedByteStringLiteral(_)
        | Value::TripleSingleQuotedString(_)
        | Value::TripleDoubleQuotedString(_)
        | Value::TripleSingleQuotedByteStringLiteral(_)
        | Value::TripleDoubleQuotedByteStringLiteral(_)
        | Value::SingleQuotedRawStringLiteral(_)
        | Value::DoubleQuotedRawStringLiteral(_)
        | Value::TripleSingleQuotedRawStringLiteral(_)
        | Value::TripleDoubleQuotedRawStringLiteral(_)
        | Value::HexStringLiteral(_)
        | Value::Null
        | Value::Placeholder(_) => None,
    }
}

pub(crate) fn rewrite_recursive_unnests_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_exprs: &[Expr],
) -> Result<Vec<Expr>> {
    Ok(original_exprs
        .iter()
        .map(|expr| {
            rewrite_recursive_unnest_bottom_up(
                input,
                unnest_placeholder_columns,
                inner_projection_exprs,
                expr,
            )
        })
        .collect::<Result<Vec<_>>>()?
        .into_iter()
        .flatten()
        .collect::<Vec<_>>())
}

pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";

/*
This rewriter is only useful when driven by a combined top-down / bottom-up
traversal (i.e. `TreeNode::rewrite` with `f_down` and `f_up`).
A full example of how the transformation works is given in the docs of `f_up` below.
 */
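/// Rewrites `unnest` expressions bottom-up into references to placeholder
/// columns produced by an inner projection.
///
/// Illustrative sketch (column names taken from the tests below):
///
/// ```text
/// unnest(unnest(3d_col))
///   inner projection expr:  3d_col AS __unnest_placeholder(3d_col)
///   outer reference:        __unnest_placeholder(3d_col,depth=2)
/// ```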
struct RecursiveUnnestRewriter<'a> {
    input_schema: &'a DFSchemaRef,
    root_expr: &'a Expr,
    // Useful to detect whether a child expr is part of an unnest operation or not
    top_most_unnest: Option<Unnest>,
    consecutive_unnest: Vec<Option<Unnest>>,
    inner_projection_exprs: &'a mut Vec<Expr>,
    columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    transformed_root_exprs: Option<Vec<Expr>>,
}
impl RecursiveUnnestRewriter<'_> {
    /// This struct stores the history of exprs seen during its tree traversal.
    /// Given a history such as
    /// \[None,**Unnest(exprA)**,**Unnest(exprB)**,None,None\]
    /// this function returns \[**Unnest(exprA)**,**Unnest(exprB)**\]
    ///
    /// The first item is the innermost expr
    fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
        self.consecutive_unnest
            .iter()
            .rev()
            .skip_while(|item| item.is_none())
            .take_while(|item| item.is_some())
            .to_owned()
            .cloned()
            .map(|item| item.unwrap())
            .collect()
    }

    fn transform(
        &mut self,
        level: usize,
        alias_name: String,
        expr_in_unnest: &Expr,
        struct_allowed: bool,
    ) -> Result<Vec<Expr>> {
        let inner_expr_name = expr_in_unnest.schema_name().to_string();

        // Full context: we are trying to plan the execution as InnerProjection->Unnest->OuterProjection.
        // Inside the unnest execution, each column inside the inner projection
        // will be transformed into new columns. Thus we need to keep track of these placeholder column names
        let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})");
        let post_unnest_name =
            format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})");
        // This is due to the fact that the unnest transformation should keep the original
        // column name as is, to comply with GROUP BY and ORDER BY
        let placeholder_column = Column::from_name(placeholder_name.clone());

        let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?;

        match data_type {
            DataType::Struct(inner_fields) => {
                if !struct_allowed {
                    return internal_err!("unnest on struct can only be applied at the root level of select expression");
                }
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );
                self.columns_unnestings
                    .insert(Column::from_name(placeholder_name.clone()), None);
                Ok(
                    get_struct_unnested_columns(&placeholder_name, &inner_fields)
                        .into_iter()
                        .map(Expr::Column)
                        .collect(),
                )
            }
            DataType::List(_)
            | DataType::FixedSizeList(_, _)
            | DataType::LargeList(_) => {
                push_projection_dedupl(
                    self.inner_projection_exprs,
                    expr_in_unnest.clone().alias(placeholder_name.clone()),
                );

                let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
                let list_unnesting = self
                    .columns_unnestings
                    .entry(placeholder_column)
                    .or_insert(Some(vec![]));
                let unnesting = ColumnUnnestList {
                    output_column: Column::from_name(post_unnest_name),
                    depth: level,
                };
                let list_unnestings = list_unnesting.as_mut().unwrap();
                if !list_unnestings.contains(&unnesting) {
                    list_unnestings.push(unnesting);
                }
                Ok(vec![post_unnest_expr])
            }
            _ => {
                internal_err!("unnest on non-list or struct type is not supported")
            }
        }
    }
}

impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
    type Node = Expr;

    /// This downward traversal needs to keep track of:
    /// - Whether or not some unnest expr has been visited from the top until the current node
    /// - If some unnest expr has been visited, maintain a stack of such information, this
    ///   is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))**
    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref unnest_expr) = expr {
            let (data_type, _) =
                unnest_expr.expr.data_type_and_nullable(self.input_schema)?;
            self.consecutive_unnest.push(Some(unnest_expr.clone()));
            // if the expr inside unnest is a struct, do not consider
            // the next unnest as a consecutive unnest (if any),
            // meaning unnest(unnest(struct_arr_col)) can't
            // be interpreted as unnest(struct_arr_col, depth:=2)
            // but has to be split into multiple unnest logical plans instead
            // a.k.a:
            // - unnest(struct_col)
            //      unnest(struct_arr_col) as struct_col

            if let DataType::Struct(_) = data_type {
                self.consecutive_unnest.push(None);
            }
            if self.top_most_unnest.is_none() {
                self.top_most_unnest = Some(unnest_expr.clone());
            }

            Ok(Transformed::no(expr))
        } else {
            self.consecutive_unnest.push(None);
            Ok(Transformed::no(expr))
        }
    }

    /// The rewriting only happens when the traversal has reached the top-most unnest expr
    /// within a sequence of consecutive unnest exprs
    ///
    /// For example an expr of **unnest(unnest(column1)) + unnest(unnest(unnest(column2)))**
    /// ```text
    ///                         ┌──────────────────┐
    ///                         │    binaryexpr    │
    ///                         │                  │
    ///                         └──────────────────┘
    ///                f_down  / /            │ │
    ///                       / / f_up        │ │
    ///                      / /        f_down│ │f_up
    ///                  unnest               │ │
    ///                                       │ │
    ///       f_down  / / f_up(rewriting)     │ │
    ///              / /
    ///             / /                      unnest
    ///         unnest
    ///                           f_down  / / f_up(rewriting)
    /// f_down / /f_up                   / /
    ///       / /                       / /
    ///      / /                    unnest
    ///   column1
    ///                     f_down / /f_up
    ///                           / /
    ///                          / /
    ///                       column2
    /// ```
    ///
    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
        if let Expr::Unnest(ref traversing_unnest) = expr {
            if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
                self.top_most_unnest = None;
            }
            // Find, inside consecutive_unnest, the sequence of continuous unnest exprs

            // Get the latest consecutive unnest exprs
            // and check if the current upward traversal is returning to the root expr
            // for example, given an expr `unnest(unnest(col))` the traversal happens like:
            // down(unnest) -> down(unnest) -> down(col) -> up(col) -> up(unnest) -> up(unnest)
            // the result of such a traversal is unnest(col, depth:=2)
            let unnest_stack = self.get_latest_consecutive_unnest();

            // This traversal has reached the top-most unnest again
            // e.g Unnest(top) -> Unnest(2nd) -> Column(bottom)
            // -> Unnest(2nd) -> Unnest(top) a.k.a here
            // Thus
            // Unnest(Unnest(some_col)) is rewritten into Unnest(some_col, depth:=2)
            if traversing_unnest == unnest_stack.last().unwrap() {
                let most_inner = unnest_stack.first().unwrap();
                let inner_expr = most_inner.expr.as_ref();
                // unnest(unnest(struct_arr_col)) is not allowed to be done recursively;
                // it needs to be split into multiple unnest logical plans:
                // unnest(struct_arr)
                //  unnest(struct_arr_col) as struct_arr
                // instead of unnest(struct_arr_col, depth = 2)

                let unnest_recursion = unnest_stack.len();
                let struct_allowed = (&expr == self.root_expr) && unnest_recursion == 1;

                let mut transformed_exprs = self.transform(
                    unnest_recursion,
                    expr.schema_name().to_string(),
                    inner_expr,
                    struct_allowed,
                )?;
                if struct_allowed {
                    self.transformed_root_exprs = Some(transformed_exprs.clone());
                }
                return Ok(Transformed::new(
                    transformed_exprs.swap_remove(0),
                    true,
                    TreeNodeRecursion::Continue,
                ));
            }
        } else {
            self.consecutive_unnest.push(None);
        }

        // For column exprs that are not descendants of any unnest node,
        // retain their projection
        // e.g given the expr tree unnest(col_a) + col_b, we have to retain the projection of col_b
        // this condition can be checked by maintaining an Option<top most unnest>
        if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
            push_projection_dedupl(self.inner_projection_exprs, expr.clone());
        }

        Ok(Transformed::no(expr))
    }
}

fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
    let schema_name = expr.schema_name().to_string();
    if !projection
        .iter()
        .any(|e| e.schema_name().to_string() == schema_name)
    {
        projection.push(expr);
    }
}

/// The context is that we want to rewrite unnest() into InnerProjection->Unnest->OuterProjection.
/// Given an expression which contains an unnest expr as one of its children,
/// the transformation depends on the unnest type:
/// - For a list column: unnest(col) with type list -> unnest(col) with type list::item
/// - For a struct column: unnest(struct(field1, field2)) -> unnest(struct).field1, unnest(struct).field2
///
/// The transformed exprs will be used in the outer projection.
/// If, along the path from root to bottom, there are multiple unnest expressions, the transformation
/// is done only for the bottom-most expression
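///
/// Illustrative sketch (mirrors `test_transform_bottom_unnest` below): for the
/// select expr `unnest(array_col) + 1`,
///
/// ```text
/// outer projection expr:  (__unnest_placeholder(array_col,depth=1) AS "UNNEST(array_col)") + 1
/// inner projection expr:  array_col AS __unnest_placeholder(array_col)
/// ```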
pub(crate) fn rewrite_recursive_unnest_bottom_up(
    input: &LogicalPlan,
    unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    inner_projection_exprs: &mut Vec<Expr>,
    original_expr: &Expr,
) -> Result<Vec<Expr>> {
    let mut rewriter = RecursiveUnnestRewriter {
        input_schema: input.schema(),
        root_expr: original_expr,
        top_most_unnest: None,
        consecutive_unnest: vec![],
        inner_projection_exprs,
        columns_unnestings: unnest_placeholder_columns,
        transformed_root_exprs: None,
    };

    // This transformation is only done for list unnest;
    // struct unnest is done at the root level, and at a later stage,
    // because the TreeNode API only supports transforming into 1 Expr, while
    // unnesting a struct will be transformed into multiple Exprs
    // TODO: This can be resolved after this issue is resolved: https://github.com/apache/datafusion/issues/10102
    //
    // The transformation looks like:
    // - unnest(array_col) will be transformed into Column("unnest_place_holder(array_col)")
    // - unnest(array_col) + 1 will be transformed into Column("unnest_place_holder(array_col) + 1")
    let Transformed {
        data: transformed_expr,
        transformed,
        tnr: _,
    } = original_expr.clone().rewrite(&mut rewriter)?;

    if !transformed {
        // TODO: remove the next line after `Expr::Wildcard` is removed
        #[expect(deprecated)]
        if matches!(&transformed_expr, Expr::Column(_))
            || matches!(&transformed_expr, Expr::Wildcard { .. })
        {
            push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
            Ok(vec![transformed_expr])
        } else {
            // We need to evaluate the expr in the inner projection;
            // the outer projection just selects it by name
            let column_name = transformed_expr.schema_name().to_string();
            push_projection_dedupl(inner_projection_exprs, transformed_expr);
            Ok(vec![Expr::Column(Column::from_name(column_name))])
        }
    } else {
        if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
            return Ok(transformed_root_exprs);
        }
        Ok(vec![transformed_expr])
    }
}

#[cfg(test)]
mod tests {
    use std::{ops::Add, sync::Arc};

    use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
    use datafusion_common::{Column, DFSchema, Result};
    use datafusion_expr::{
        col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan,
    };
    use datafusion_functions::core::expr_ext::FieldAccessor;
    use datafusion_functions_aggregate::expr_fn::count;

    use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
    use indexmap::IndexMap;

    fn column_unnests_eq(
        l: Vec<&str>,
        r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    ) {
        let r_formatted: Vec<String> = r
            .iter()
            .map(|i| match i.1 {
                None => format!("{}", i.0),
                Some(vec) => format!(
                    "{}=>[{}]",
                    i.0,
                    vec.iter()
                        .map(|i| format!("{i}"))
                        .collect::<Vec<String>>()
                        .join(", ")
                ),
            })
            .collect();
        let l_formatted: Vec<String> = l.iter().map(|i| i.to_string()).collect();
        assert_eq!(l_formatted, r_formatted);
    }

    #[test]
    fn test_transform_bottom_unnest_recursive() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "3d_col",
                ArrowDataType::List(Arc::new(Field::new(
                    "2d_col",
                    ArrowDataType::List(Arc::new(Field::new(
                        "elements",
                        ArrowDataType::Int64,
                        true,
                    ))),
                    true,
                ))),
                true,
            ),
            Field::new("i64_col", ArrowDataType::Int64, true),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(unnest(3d_col)) + unnest(unnest(3d_col))
        let original_expr = unnest(unnest(col("3d_col")))
            .add(unnest(unnest(col("3d_col"))))
            .add(col("i64_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // Only the bottom-most unnest exprs are transformed
        assert_eq!(
            transformed_exprs,
            vec![col("__unnest_placeholder(3d_col,depth=2)")
                .alias("UNNEST(UNNEST(3d_col))")
                .add(
                    col("__unnest_placeholder(3d_col,depth=2)")
                        .alias("UNNEST(UNNEST(3d_col))")
                )
                .add(col("i64_col"))]
        );
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
            ],
            &unnest_placeholder_columns,
        );

        // Still reference 3d_col in the original schema but with an alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        // unnest(3d_col) as 2d_col
        let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr_2,
        )?;

        assert_eq!(
            transformed_exprs,
            vec![
                (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
                    .alias("2d_col")
            ]
        );
        column_unnests_eq(
            vec!["__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]"],
            &unnest_placeholder_columns,
        );
        // Still reference 3d_col in the original schema but with an alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        Ok(())
    }

    #[test]
    fn test_transform_bottom_unnest() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new(
                "struct_col",
                ArrowDataType::Struct(Fields::from(vec![
                    Field::new("field1", ArrowDataType::Int32, false),
                    Field::new("field2", ArrowDataType::Int32, false),
                ])),
                false,
            ),
            Field::new(
                "array_col",
                ArrowDataType::List(Arc::new(Field::new_list_field(
                    ArrowDataType::Int64,
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(struct_col)
        let original_expr = unnest(col("struct_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(struct_col).field1"),
                col("__unnest_placeholder(struct_col).field2"),
            ]
        );
        column_unnests_eq(
            vec!["__unnest_placeholder(struct_col)"],
            &unnest_placeholder_columns,
        );
        // Still reference struct_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
        );

        // unnest(array_col) + 1
        let original_expr = unnest(col("array_col")).add(lit(1i64));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_col)",
                "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Only transform the unnest children
        assert_eq!(
            transformed_exprs,
            vec![col("__unnest_placeholder(array_col,depth=1)")
                .alias("UNNEST(array_col)")
                .add(lit(1i64))]
        );

        // Keep appending to the current vector
        // Still reference array_col in original schema but with alias,
        // to avoid colliding with the projection on the column itself if any
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("struct_col").alias("__unnest_placeholder(struct_col)"),
                col("array_col").alias("__unnest_placeholder(array_col)")
            ]
        );

        Ok(())
    }

    // Unnest -> field access -> unnest
    #[test]
    fn test_transform_non_consecutive_unnests() -> Result<()> {
        // List of struct
        // [struct{'subfield1':list(i64), 'subfield2':list(utf8)}]
        let schema = Schema::new(vec![
            Field::new(
                "struct_list",
                ArrowDataType::List(Arc::new(Field::new(
                    "element",
                    ArrowDataType::Struct(Fields::from(vec![
                        Field::new(
                            // list of i64
                            "subfield1",
                            ArrowDataType::List(Arc::new(Field::new(
                                "i64_element",
                                ArrowDataType::Int64,
                                true,
                            ))),
                            true,
                        ),
                        Field::new(
                            // list of utf8
                            "subfield2",
                            ArrowDataType::List(Arc::new(Field::new(
                                "utf8_element",
                                ArrowDataType::Utf8,
                                true,
                            ))),
                            true,
                        ),
                    ])),
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // An expr with multiple unnest
        let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr1,
        )?;
        // Only the innermost / bottom-most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield1")
            )]
        );

        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        // continue rewriting another expr in the select list
        let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr2,
        )?;
        // Only the innermost / bottom-most unnest is transformed
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield2")
            )]
        );

        // unnest placeholder columns remain the same
        // because expr1 and expr2 derive from the same unnest result
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        Ok(())
    }

    #[test]
    fn test_resolve_positions_to_exprs() -> Result<()> {
        let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];

        // Assert 1 resolved as first column in select list
        let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
        assert_eq!(resolved, col("c1"));

        // Assert error if index out of select clause bounds
        let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
        )));

        let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
        assert!(resolved.is_err_and(|e| e.message().contains(
            "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
        )));

        // Assert expression returned as-is
        let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
        assert_eq!(resolved, lit("text"));

        let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
        assert_eq!(resolved, col("fake"));

        Ok(())
    }
}