datafusion_expr/utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression utilities
19
20use std::cmp::Ordering;
21use std::collections::{BTreeSet, HashSet};
22use std::sync::Arc;
23
24use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
25use crate::expr_rewriter::strip_outer_reference;
26use crate::{
27    and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
28};
29use datafusion_expr_common::signature::{Signature, TypeSignature};
30
31use arrow::datatypes::{DataType, Field, Schema};
32use datafusion_common::tree_node::{
33    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34};
35use datafusion_common::utils::get_at_indices;
36use datafusion_common::{
37    internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap,
38    Result, TableReference,
39};
40
41use indexmap::IndexSet;
42use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
43
44pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
45
/// The value to which `COUNT(*)` is expanded in
/// `COUNT(<constant>)` expressions
48pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
49
50/// Count the number of distinct exprs in a list of group by expressions. If the
51/// first element is a `GroupingSet` expression then it must be the only expr.
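///
/// # Example
///
/// A minimal sketch of the non-`GroupingSet` case (duplicate expressions are counted once):
///
/// ```
/// # use datafusion_expr::col;
/// # use datafusion_expr::utils::grouping_set_expr_count;
/// let group_expr = vec![col("a"), col("b"), col("a")];
/// assert_eq!(grouping_set_expr_count(&group_expr).unwrap(), 2);
/// ```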
52pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
53    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
54        if group_expr.len() > 1 {
55            return plan_err!(
56                "Invalid group by expressions, GroupingSet must be the only expression"
57            );
58        }
        // Grouping sets have an additional integer column for the grouping id
60        Ok(grouping_set.distinct_expr().len() + 1)
61    } else {
62        grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
63    }
64}
65
66/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
67/// including the empty set and S itself.
68///
69/// Example:
70///
71/// If S is the set {x, y, z}, then all the subsets of S are \
72///  {} \
73///  {x} \
74///  {y} \
75///  {z} \
76///  {x, y} \
77///  {x, z} \
78///  {y, z} \
79///  {x, y, z} \
80///  and hence the power set of S is {{}, {x}, {y}, {z}, {x, y}, {x, z}, {y, z}, {x, y, z}}.
81///
82/// [power set]: https://en.wikipedia.org/wiki/Power_set
83fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
84    if slice.len() >= 64 {
85        return Err("The size of the set must be less than 64.".into());
86    }
87
88    let mut v = Vec::new();
89    for mask in 0..(1 << slice.len()) {
90        let mut ss = vec![];
91        let mut bitset = mask;
92        while bitset > 0 {
93            let rightmost: u64 = bitset & !(bitset - 1);
94            let idx = rightmost.trailing_zeros();
95            let item = slice.get(idx as usize).unwrap();
96            ss.push(item);
97            // zero the trailing bit
98            bitset &= bitset - 1;
99        }
100        v.push(ss);
101    }
102    Ok(v)
103}
104
/// Check that the number of expressions contained in a grouping set does not exceed the limit
106fn check_grouping_set_size_limit(size: usize) -> Result<()> {
107    let max_grouping_set_size = 65535;
108    if size > max_grouping_set_size {
109        return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}");
110    }
111
112    Ok(())
113}
114
/// Check that the number of grouping sets contained in the grouping sets expression does not exceed the limit
116fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
117    let max_grouping_sets_size = 4096;
118    if size > max_grouping_sets_size {
119        return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}");
120    }
121
122    Ok(())
123}
124
125/// Merge two grouping_set
126///
127/// # Example
128/// ```text
129/// (A, B), (C, D) -> (A, B, C, D)
130/// ```
131///
132/// # Error
133/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
134///
135/// [`DataFusionError`]: datafusion_common::DataFusionError
136fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
137    check_grouping_set_size_limit(left.len() + right.len())?;
138    Ok(left.iter().chain(right.iter()).cloned().collect())
139}
140
141/// Compute the cross product of two grouping_sets
142///
143/// # Example
144/// ```text
145/// [(A, B), (C, D)], [(E), (F)] -> [(A, B, E), (A, B, F), (C, D, E), (C, D, F)]
146/// ```
147///
148/// # Error
149/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
150/// - [`DataFusionError`]: The number of grouping_set in grouping_sets exceeds the maximum limit
151///
152/// [`DataFusionError`]: datafusion_common::DataFusionError
153fn cross_join_grouping_sets<T: Clone>(
154    left: &[Vec<T>],
155    right: &[Vec<T>],
156) -> Result<Vec<Vec<T>>> {
157    let grouping_sets_size = left.len() * right.len();
158
159    check_grouping_sets_size_limit(grouping_sets_size)?;
160
161    let mut result = Vec::with_capacity(grouping_sets_size);
162    for le in left {
163        for re in right {
164            result.push(merge_grouping_set(le, re)?);
165        }
166    }
167    Ok(result)
168}
169
/// Convert multiple grouping expressions into one [`GroupingSet::GroupingSets`].\
/// If the grouping expressions do not contain [`Expr::GroupingSet`], or there is only one expression,\
/// no conversion will be performed.
173///
174/// e.g.
175///
176/// person.id,\
177/// GROUPING SETS ((person.age, person.salary),(person.age)),\
178/// ROLLUP(person.state, person.birth_date)
179///
180/// =>
181///
182/// GROUPING SETS (\
183///   (person.id, person.age, person.salary),\
184///   (person.id, person.age, person.salary, person.state),\
185///   (person.id, person.age, person.salary, person.state, person.birth_date),\
186///   (person.id, person.age),\
187///   (person.id, person.age, person.state),\
188///   (person.id, person.age, person.state, person.birth_date)\
189/// )
190pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
191    let has_grouping_set = group_expr
192        .iter()
193        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
194    if !has_grouping_set || group_expr.len() == 1 {
195        return Ok(group_expr);
196    }
    // Only process mixed grouping sets
198    let partial_sets = group_expr
199        .iter()
200        .map(|expr| {
201            let exprs = match expr {
202                Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
203                    check_grouping_sets_size_limit(grouping_sets.len())?;
204                    grouping_sets.iter().map(|e| e.iter().collect()).collect()
205                }
206                Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
207                    let grouping_sets = powerset(group_exprs)
208                        .map_err(|e| plan_datafusion_err!("{}", e))?;
209                    check_grouping_sets_size_limit(grouping_sets.len())?;
210                    grouping_sets
211                }
212                Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
213                    let size = group_exprs.len();
214                    let slice = group_exprs.as_slice();
215                    check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
216                    (0..(size + 1))
217                        .map(|i| slice[0..i].iter().collect())
218                        .collect()
219                }
220                expr => vec![vec![expr]],
221            };
222            Ok(exprs)
223        })
224        .collect::<Result<Vec<_>>>()?;
225
226    // Cross Join
227    let grouping_sets = partial_sets
228        .into_iter()
229        .map(Ok)
230        .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
231        .transpose()?
232        .map(|e| {
233            e.into_iter()
234                .map(|e| e.into_iter().cloned().collect())
235                .collect()
236        })
237        .unwrap_or_default();
238
239    Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
240        grouping_sets,
241    ))])
242}
243
244/// Find all distinct exprs in a list of group by expressions. If the
/// first element is a `GroupingSet` expression, then it must be the only expression.
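///
/// # Example
///
/// An illustrative sketch for plain (non grouping-set) group expressions:
///
/// ```
/// # use datafusion_expr::col;
/// # use datafusion_expr::utils::grouping_set_to_exprlist;
/// let a = col("a");
/// let b = col("b");
/// let group_expr = vec![a.clone(), b.clone(), a.clone()];
/// assert_eq!(grouping_set_to_exprlist(&group_expr).unwrap(), vec![&a, &b]);
/// ```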
246pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
247    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
248        if group_expr.len() > 1 {
249            return plan_err!(
250                "Invalid group by expressions, GroupingSet must be the only expression"
251            );
252        }
253        Ok(grouping_set.distinct_expr())
254    } else {
255        Ok(group_expr
256            .iter()
257            .collect::<IndexSet<_>>()
258            .into_iter()
259            .collect())
260    }
261}
262
263/// Recursively walk an expression tree, collecting the unique set of columns
264/// referenced in the expression
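///
/// # Example
///
/// A small sketch; duplicate references resolve to a single [`Column`]:
///
/// ```
/// # use std::collections::HashSet;
/// # use datafusion_common::Column;
/// # use datafusion_expr::{col, lit};
/// # use datafusion_expr::utils::expr_to_columns;
/// let expr = col("a").eq(lit(1)).and(col("a").eq(lit(10)));
/// let mut accum: HashSet<Column> = HashSet::new();
/// expr_to_columns(&expr, &mut accum).unwrap();
/// assert_eq!(accum.len(), 1);
/// assert!(accum.contains(&Column::from_name("a")));
/// ```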
265pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
266    expr.apply(|expr| {
267        match expr {
268            Expr::Column(qc) => {
269                accum.insert(qc.clone());
270            }
271            // Use explicit pattern match instead of a default
272            // implementation, so that in the future if someone adds
273            // new Expr types, they will check here as well
274            // TODO: remove the next line after `Expr::Wildcard` is removed
275            #[expect(deprecated)]
276            Expr::Unnest(_)
277            | Expr::ScalarVariable(_, _)
278            | Expr::Alias(_)
279            | Expr::Literal(_, _)
280            | Expr::BinaryExpr { .. }
281            | Expr::Like { .. }
282            | Expr::SimilarTo { .. }
283            | Expr::Not(_)
284            | Expr::IsNotNull(_)
285            | Expr::IsNull(_)
286            | Expr::IsTrue(_)
287            | Expr::IsFalse(_)
288            | Expr::IsUnknown(_)
289            | Expr::IsNotTrue(_)
290            | Expr::IsNotFalse(_)
291            | Expr::IsNotUnknown(_)
292            | Expr::Negative(_)
293            | Expr::Between { .. }
294            | Expr::Case { .. }
295            | Expr::Cast { .. }
296            | Expr::TryCast { .. }
297            | Expr::ScalarFunction(..)
298            | Expr::WindowFunction { .. }
299            | Expr::AggregateFunction { .. }
300            | Expr::GroupingSet(_)
301            | Expr::InList { .. }
302            | Expr::Exists { .. }
303            | Expr::InSubquery(_)
304            | Expr::ScalarSubquery(_)
305            | Expr::Wildcard { .. }
306            | Expr::Placeholder(_)
307            | Expr::OuterReferenceColumn { .. } => {}
308        }
309        Ok(TreeNodeRecursion::Continue)
310    })
311    .map(|_| ())
312}
313
/// Find excluded columns in the schema, if any.
/// For example, `SELECT * EXCLUDE(col1, col2)` would return `vec![col1, col2]`
316fn get_excluded_columns(
317    opt_exclude: Option<&ExcludeSelectItem>,
318    opt_except: Option<&ExceptSelectItem>,
319    schema: &DFSchema,
320    qualifier: Option<&TableReference>,
321) -> Result<Vec<Column>> {
322    let mut idents = vec![];
323    if let Some(excepts) = opt_except {
324        idents.push(&excepts.first_element);
325        idents.extend(&excepts.additional_elements);
326    }
327    if let Some(exclude) = opt_exclude {
328        match exclude {
329            ExcludeSelectItem::Single(ident) => idents.push(ident),
330            ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
331        }
332    }
333    // Excluded columns should be unique
334    let n_elem = idents.len();
335    let unique_idents = idents.into_iter().collect::<HashSet<_>>();
    // If the HashSet size and the vector length differ, some of the excluded columns
    // are duplicated. In this case return an error.
338    if n_elem != unique_idents.len() {
339        return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
340    }
341
342    let mut result = vec![];
343    for ident in unique_idents.into_iter() {
344        let col_name = ident.value.as_str();
345        let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
346        result.push(Column::from((qualifier, field)));
347    }
348    Ok(result)
349}
350
/// Returns all `Expr`s in the schema, except the `Column`s in `columns_to_skip`
352fn get_exprs_except_skipped(
353    schema: &DFSchema,
354    columns_to_skip: HashSet<Column>,
355) -> Vec<Expr> {
356    if columns_to_skip.is_empty() {
357        schema.iter().map(Expr::from).collect::<Vec<Expr>>()
358    } else {
359        schema
360            .columns()
361            .iter()
362            .filter_map(|c| {
363                if !columns_to_skip.contains(c) {
364                    Some(Expr::Column(c.clone()))
365                } else {
366                    None
367                }
368            })
369            .collect::<Vec<Expr>>()
370    }
371}
372
373/// For each column specified in the USING JOIN condition, the JOIN plan outputs it twice
374/// (once for each join side), but an unqualified wildcard should include it only once.
375/// This function returns the columns that should be excluded.
376fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
377    let using_columns = plan.using_columns()?;
378    let excluded = using_columns
379        .into_iter()
380        // For each USING JOIN condition, only expand to one of each join column in projection
381        .flat_map(|cols| {
382            let mut cols = cols.into_iter().collect::<Vec<_>>();
383            // sort join columns to make sure we consistently keep the same
384            // qualified column
385            cols.sort();
386            let mut out_column_names: HashSet<String> = HashSet::new();
387            cols.into_iter().filter_map(move |c| {
388                if out_column_names.contains(&c.name) {
389                    Some(c)
390                } else {
391                    out_column_names.insert(c.name);
392                    None
393                }
394            })
395        })
396        .collect::<HashSet<_>>();
397    Ok(excluded)
398}
399
400/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s.
401pub fn expand_wildcard(
402    schema: &DFSchema,
403    plan: &LogicalPlan,
404    wildcard_options: Option<&WildcardOptions>,
405) -> Result<Vec<Expr>> {
406    let mut columns_to_skip = exclude_using_columns(plan)?;
407    let excluded_columns = if let Some(WildcardOptions {
408        exclude: opt_exclude,
409        except: opt_except,
410        ..
411    }) = wildcard_options
412    {
413        get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
414    } else {
415        vec![]
416    };
417    // Add each excluded `Column` to columns_to_skip
418    columns_to_skip.extend(excluded_columns);
419    Ok(get_exprs_except_skipped(schema, columns_to_skip))
420}
421
422/// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s.
423pub fn expand_qualified_wildcard(
424    qualifier: &TableReference,
425    schema: &DFSchema,
426    wildcard_options: Option<&WildcardOptions>,
427) -> Result<Vec<Expr>> {
428    let qualified_indices = schema.fields_indices_with_qualified(qualifier);
429    let projected_func_dependencies = schema
430        .functional_dependencies()
431        .project_functional_dependencies(&qualified_indices, qualified_indices.len());
432    let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
433    if fields_with_qualified.is_empty() {
434        return plan_err!("Invalid qualifier {qualifier}");
435    }
436
437    let qualified_schema = Arc::new(Schema::new_with_metadata(
438        fields_with_qualified,
439        schema.metadata().clone(),
440    ));
441    let qualified_dfschema =
442        DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
443            .with_functional_dependencies(projected_func_dependencies)?;
444    let excluded_columns = if let Some(WildcardOptions {
445        exclude: opt_exclude,
446        except: opt_except,
447        ..
448    }) = wildcard_options
449    {
450        get_excluded_columns(
451            opt_exclude.as_ref(),
452            opt_except.as_ref(),
453            schema,
454            Some(qualifier),
455        )?
456    } else {
457        vec![]
458    };
459    // Add each excluded `Column` to columns_to_skip
460    let mut columns_to_skip = HashSet::new();
461    columns_to_skip.extend(excluded_columns);
462    Ok(get_exprs_except_skipped(
463        &qualified_dfschema,
464        columns_to_skip,
465    ))
466}
467
/// A sort key for a window expression: each element pairs a [`Sort`] expression (coming from
/// either the `PARTITION BY` or the `ORDER BY` columns) with a `bool` that is `true` if the
/// sort expression comes from a `PARTITION BY` column and `false` if it comes from `ORDER BY`
470type WindowSortKey = Vec<(Sort, bool)>;
471
/// Generate a sort key for a given window expression's `partition_by` and `order_by` expressions
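///
/// # Example
///
/// A minimal sketch; the `bool` in each returned pair records whether the key came from
/// `PARTITION BY` (`true`) or `ORDER BY` (`false`):
///
/// ```
/// # use datafusion_expr::col;
/// # use datafusion_expr::expr::Sort;
/// # use datafusion_expr::utils::generate_sort_key;
/// let partition_by = vec![col("a")];
/// let order_by = vec![Sort::new(col("b"), true, false)];
/// let keys = generate_sort_key(&partition_by, &order_by).unwrap();
/// assert_eq!(keys.len(), 2);
/// assert!(keys[0].1);  // from PARTITION BY
/// assert!(!keys[1].1); // from ORDER BY
/// ```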
473pub fn generate_sort_key(
474    partition_by: &[Expr],
475    order_by: &[Sort],
476) -> Result<WindowSortKey> {
477    let normalized_order_by_keys = order_by
478        .iter()
479        .map(|e| {
480            let Sort { expr, .. } = e;
481            Sort::new(expr.clone(), true, false)
482        })
483        .collect::<Vec<_>>();
484
485    let mut final_sort_keys = vec![];
486    let mut is_partition_flag = vec![];
487    partition_by.iter().for_each(|e| {
        // By default, create the sort key with `ASC` and `NULLS LAST` to be consistent with
489        // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html
490        let e = e.clone().sort(true, false);
491        if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
492            let order_by_key = &order_by[pos];
493            if !final_sort_keys.contains(order_by_key) {
494                final_sort_keys.push(order_by_key.clone());
495                is_partition_flag.push(true);
496            }
497        } else if !final_sort_keys.contains(&e) {
498            final_sort_keys.push(e);
499            is_partition_flag.push(true);
500        }
501    });
502
503    order_by.iter().for_each(|e| {
504        if !final_sort_keys.contains(e) {
505            final_sort_keys.push(e.clone());
506            is_partition_flag.push(false);
507        }
508    });
509    let res = final_sort_keys
510        .into_iter()
511        .zip(is_partition_flag)
512        .collect::<Vec<_>>();
513    Ok(res)
514}
515
/// Compare two sort expressions, following PostgreSQL's common_prefix_cmp():
517/// <https://github.com/postgres/postgres/blob/master/src/backend/optimizer/plan/planner.c>
518pub fn compare_sort_expr(
519    sort_expr_a: &Sort,
520    sort_expr_b: &Sort,
521    schema: &DFSchemaRef,
522) -> Ordering {
523    let Sort {
524        expr: expr_a,
525        asc: asc_a,
526        nulls_first: nulls_first_a,
527    } = sort_expr_a;
528
529    let Sort {
530        expr: expr_b,
531        asc: asc_b,
532        nulls_first: nulls_first_b,
533    } = sort_expr_b;
534
535    let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
536    let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
537    for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
538        match idx_a.cmp(idx_b) {
539            Ordering::Less => {
540                return Ordering::Less;
541            }
542            Ordering::Greater => {
543                return Ordering::Greater;
544            }
545            Ordering::Equal => {}
546        }
547    }
548    match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
549        Ordering::Less => return Ordering::Greater,
550        Ordering::Greater => {
551            return Ordering::Less;
552        }
553        Ordering::Equal => {}
554    }
555    match (asc_a, asc_b) {
556        (true, false) => {
557            return Ordering::Greater;
558        }
559        (false, true) => {
560            return Ordering::Less;
561        }
562        _ => {}
563    }
564    match (nulls_first_a, nulls_first_b) {
565        (true, false) => {
566            return Ordering::Less;
567        }
568        (false, true) => {
569            return Ordering::Greater;
570        }
571        _ => {}
572    }
573    Ordering::Equal
574}
575
/// Group a list of window expressions by their sort keys (derived from their `PARTITION BY` and `ORDER BY` expressions)
577pub fn group_window_expr_by_sort_keys(
578    window_expr: impl IntoIterator<Item = Expr>,
579) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
580    let mut result = vec![];
581    window_expr.into_iter().try_for_each(|expr| match &expr {
582        Expr::WindowFunction(window_fun) => {
583            let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params;
584            let sort_key = generate_sort_key(partition_by, order_by)?;
585            if let Some((_, values)) = result.iter_mut().find(
586                |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
587            ) {
588                values.push(expr);
589            } else {
590                result.push((sort_key, vec![expr]))
591            }
592            Ok(())
593        }
594        other => internal_err!(
595            "Impossibly got non-window expr {other:?}"
596        ),
597    })?;
598    Ok(result)
599}
600
601/// Collect all deeply nested `Expr::AggregateFunction`.
602/// They are returned in order of occurrence (depth
603/// first), with duplicates omitted.
604pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
605    find_exprs_in_exprs(exprs, &|nested_expr| {
606        matches!(nested_expr, Expr::AggregateFunction { .. })
607    })
608}
609
610/// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence
611/// (depth first), with duplicates omitted.
612pub fn find_window_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
613    find_exprs_in_exprs(exprs, &|nested_expr| {
614        matches!(nested_expr, Expr::WindowFunction { .. })
615    })
616}
617
618/// Collect all deeply nested `Expr::OuterReferenceColumn`. They are returned in order of occurrence
619/// (depth first), with duplicates omitted.
620pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
621    find_exprs_in_expr(expr, &|nested_expr| {
622        matches!(nested_expr, Expr::OuterReferenceColumn { .. })
623    })
624}
625
626/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
627/// pass the provided test. The returned `Expr`'s are deduplicated and returned
628/// in order of appearance (depth first).
629fn find_exprs_in_exprs<'a, F>(
630    exprs: impl IntoIterator<Item = &'a Expr>,
631    test_fn: &F,
632) -> Vec<Expr>
633where
634    F: Fn(&Expr) -> bool,
635{
636    exprs
637        .into_iter()
638        .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
639        .fold(vec![], |mut acc, expr| {
640            if !acc.contains(&expr) {
641                acc.push(expr)
642            }
643            acc
644        })
645}
646
647/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
648/// provided test. The returned `Expr`'s are deduplicated and returned in order
649/// of appearance (depth first).
650fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
651where
652    F: Fn(&Expr) -> bool,
653{
654    let mut exprs = vec![];
655    expr.apply(|expr| {
656        if test_fn(expr) {
657            if !(exprs.contains(expr)) {
658                exprs.push(expr.clone())
659            }
660            // Stop recursing down this expr once we find a match
661            return Ok(TreeNodeRecursion::Jump);
662        }
663
664        Ok(TreeNodeRecursion::Continue)
665    })
    // the closure passed to `apply` always returns Ok, so this will too
667    .expect("no way to return error during recursion");
668    exprs
669}
670
671/// Recursively inspect an [`Expr`] and all its children.
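///
/// # Example
///
/// A small sketch using a counting closure (the error type is chosen arbitrarily here):
///
/// ```
/// # use datafusion_expr::{col, lit};
/// # use datafusion_expr::utils::inspect_expr_pre;
/// let expr = col("a").eq(lit(1));
/// let mut count = 0;
/// inspect_expr_pre(&expr, |_expr| -> Result<(), ()> { count += 1; Ok(()) }).unwrap();
/// assert_eq!(count, 3); // the binary expression, the column `a`, and the literal `1`
/// ```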
672pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
673where
674    F: FnMut(&Expr) -> Result<(), E>,
675{
676    let mut err = Ok(());
677    expr.apply(|expr| {
678        if let Err(e) = f(expr) {
679            // Save the error for later (it may not be a DataFusionError)
680            err = Err(e);
681            Ok(TreeNodeRecursion::Stop)
682        } else {
683            // keep going
684            Ok(TreeNodeRecursion::Continue)
685        }
686    })
    // The closure always returns Ok, so this will too
688    .expect("no way to return error during recursion");
689
690    err
691}
692
693/// Create field meta-data from an expression, for use in a result set schema
694pub fn exprlist_to_fields<'a>(
695    exprs: impl IntoIterator<Item = &'a Expr>,
696    plan: &LogicalPlan,
697) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
698    // Look for exact match in plan's output schema
699    let input_schema = plan.schema();
700    exprs
701        .into_iter()
702        .map(|e| e.to_field(input_schema))
703        .collect()
704}
705
/// Convert an expression into a `Column` expression if it is already provided by the input plan.
///
/// For example, it rewrites:
///
/// ```text
/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
/// .project(vec![col("c1"), sum(col("c2"))])?
/// ```
///
/// Into:
///
/// ```text
/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
/// .project(vec![col("c1"), col("SUM(c2)")])?
/// ```
721pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
722    let output_exprs = match input.columnized_output_exprs() {
723        Ok(exprs) if !exprs.is_empty() => exprs,
724        _ => return Ok(e),
725    };
726    let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
727    e.transform_down(|node: Expr| match exprs_map.get(&node) {
728        Some(column) => Ok(Transformed::new(
729            Expr::Column(column.clone()),
730            true,
731            TreeNodeRecursion::Jump,
732        )),
733        None => Ok(Transformed::no(node)),
734    })
735    .data()
736}
737
738/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
739/// appearance (depth first), and may contain duplicates.
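///
/// # Example
///
/// A brief sketch; note the duplicated `a`:
///
/// ```
/// # use datafusion_expr::{col, lit};
/// # use datafusion_expr::utils::find_column_exprs;
/// let exprs = [col("a").eq(lit(1)), col("b").eq(col("a"))];
/// assert_eq!(find_column_exprs(&exprs), vec![col("a"), col("b"), col("a")]);
/// ```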
740pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
741    exprs
742        .iter()
743        .flat_map(find_columns_referenced_by_expr)
744        .map(Expr::Column)
745        .collect()
746}
747
748pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
749    let mut exprs = vec![];
750    e.apply(|expr| {
751        if let Expr::Column(c) = expr {
752            exprs.push(c.clone())
753        }
754        Ok(TreeNodeRecursion::Continue)
755    })
756    // As the closure always returns Ok, this "can't" error
757    .expect("Unexpected error");
758    exprs
759}
760
761/// Convert any `Expr` to an `Expr::Column`.
762pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
763    match expr {
764        Expr::Column(col) => {
765            let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
766            Ok(Expr::from(Column::from((qualifier, field))))
767        }
768        _ => Ok(Expr::Column(Column::from_name(
769            expr.schema_name().to_string(),
770        ))),
771    }
772}
773
774/// Recursively walk an expression tree, collecting the column indexes
775/// referenced in the expression
776pub(crate) fn find_column_indexes_referenced_by_expr(
777    e: &Expr,
778    schema: &DFSchemaRef,
779) -> Vec<usize> {
780    let mut indexes = vec![];
781    e.apply(|expr| {
782        match expr {
783            Expr::Column(qc) => {
784                if let Ok(idx) = schema.index_of_column(qc) {
785                    indexes.push(idx);
786                }
787            }
788            Expr::Literal(_, _) => {
789                indexes.push(usize::MAX);
790            }
791            _ => {}
792        }
793        Ok(TreeNodeRecursion::Continue)
794    })
795    .unwrap();
796    indexes
797}
798
/// Can this data type be used in hash join equal conditions?
/// The data types here come from the `equal_rows` function; if more data types are supported
/// in `create_hashes`, add those data types here as well so the join logical plan can be generated.
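///
/// # Example
///
/// An illustrative sketch (`ListView` is one of the types that is currently not hashable here):
///
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::{DataType, Field};
/// # use datafusion_expr::utils::can_hash;
/// assert!(can_hash(&DataType::Int64));
/// assert!(can_hash(&DataType::Utf8));
/// let list_view = DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true)));
/// assert!(!can_hash(&list_view));
/// ```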
802pub fn can_hash(data_type: &DataType) -> bool {
803    match data_type {
804        DataType::Null => true,
805        DataType::Boolean => true,
806        DataType::Int8 => true,
807        DataType::Int16 => true,
808        DataType::Int32 => true,
809        DataType::Int64 => true,
810        DataType::UInt8 => true,
811        DataType::UInt16 => true,
812        DataType::UInt32 => true,
813        DataType::UInt64 => true,
814        DataType::Float16 => true,
815        DataType::Float32 => true,
816        DataType::Float64 => true,
817        DataType::Decimal32(_, _) => true,
818        DataType::Decimal64(_, _) => true,
819        DataType::Decimal128(_, _) => true,
820        DataType::Decimal256(_, _) => true,
821        DataType::Timestamp(_, _) => true,
822        DataType::Utf8 => true,
823        DataType::LargeUtf8 => true,
824        DataType::Utf8View => true,
825        DataType::Binary => true,
826        DataType::LargeBinary => true,
827        DataType::BinaryView => true,
828        DataType::Date32 => true,
829        DataType::Date64 => true,
830        DataType::Time32(_) => true,
831        DataType::Time64(_) => true,
832        DataType::Duration(_) => true,
833        DataType::Interval(_) => true,
834        DataType::FixedSizeBinary(_) => true,
835        DataType::Dictionary(key_type, value_type) => {
836            DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
837        }
838        DataType::List(value_type) => can_hash(value_type.data_type()),
839        DataType::LargeList(value_type) => can_hash(value_type.data_type()),
840        DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
841        DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
842        DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
843
844        DataType::ListView(_)
845        | DataType::LargeListView(_)
846        | DataType::Union(_, _)
847        | DataType::RunEndEncoded(_, _) => false,
848    }
849}
850
851/// Check whether all columns are from the schema.
852pub fn check_all_columns_from_schema(
853    columns: &HashSet<&Column>,
854    schema: &DFSchema,
855) -> Result<bool> {
856    for col in columns.iter() {
857        let exist = schema.is_column_from_schema(col);
858        if !exist {
859            return Ok(false);
860        }
861    }
862
863    Ok(true)
864}
865
/// Given the two sides of an equijoin predicate, return a valid join key pair.
/// If there is no valid join key pair, return None.
///
/// A valid join key pair means either:
/// 1. All referenced columns of the left side are from the left schema, and
///    all referenced columns of the right side are from the right schema.
/// 2. The opposite: all referenced columns of the left side are from the right schema,
///    and all referenced columns of the right side are from the left schema.
875pub fn find_valid_equijoin_key_pair(
876    left_key: &Expr,
877    right_key: &Expr,
878    left_schema: &DFSchema,
879    right_schema: &DFSchema,
880) -> Result<Option<(Expr, Expr)>> {
881    let left_using_columns = left_key.column_refs();
882    let right_using_columns = right_key.column_refs();
883
    // Conditions like `a = 10` will be added to the non-equijoin conditions.
885    if left_using_columns.is_empty() || right_using_columns.is_empty() {
886        return Ok(None);
887    }
888
889    if check_all_columns_from_schema(&left_using_columns, left_schema)?
890        && check_all_columns_from_schema(&right_using_columns, right_schema)?
891    {
892        return Ok(Some((left_key.clone(), right_key.clone())));
893    } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
894        && check_all_columns_from_schema(&left_using_columns, right_schema)?
895    {
896        return Ok(Some((right_key.clone(), left_key.clone())));
897    }
898
899    Ok(None)
900}
901
/// Creates a detailed error message for a function called with the wrong signature.
903///
904/// For example, a query like `select round(3.14, 1.1);` would yield:
905/// ```text
906/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts.
907///     Candidate functions:
908///     round(Float64, Int64)
909///     round(Float32, Int64)
910///     round(Float64)
911///     round(Float32)
912/// ```
913pub fn generate_signature_error_msg(
914    func_name: &str,
915    func_signature: Signature,
916    input_expr_types: &[DataType],
917) -> String {
918    let candidate_signatures = func_signature
919        .type_signature
920        .to_string_repr()
921        .iter()
922        .map(|args_str| format!("\t{func_name}({args_str})"))
923        .collect::<Vec<String>>()
924        .join("\n");
925
926    format!(
927            "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
928            func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures
929        )
930}
931
932/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
933///
934/// See [`split_conjunction_owned`] for more details and an example.
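///
/// # Example
///
/// A borrowed-variant sketch of the owned example:
///
/// ```
/// # use datafusion_expr::{col, lit};
/// # use datafusion_expr::utils::split_conjunction;
/// let a = col("a").eq(lit(1));
/// let b = col("b").eq(lit(2));
/// // a=1 AND b=2
/// let expr = a.clone().and(b.clone());
/// assert_eq!(split_conjunction(&expr), vec![&a, &b]);
/// ```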
935pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
936    split_conjunction_impl(expr, vec![])
937}
938
939fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
940    match expr {
941        Expr::BinaryExpr(BinaryExpr {
942            right,
943            op: Operator::And,
944            left,
945        }) => {
946            let exprs = split_conjunction_impl(left, exprs);
947            split_conjunction_impl(right, exprs)
948        }
949        Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
950        other => {
951            exprs.push(other);
952            exprs
953        }
954    }
955}
956
957/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
958///
959/// See [`split_conjunction_owned`] for more details and an example.
960pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
961    let mut stack = vec![expr];
962    std::iter::from_fn(move || {
963        while let Some(expr) = stack.pop() {
964            match expr {
965                Expr::BinaryExpr(BinaryExpr {
966                    right,
967                    op: Operator::And,
968                    left,
969                }) => {
970                    stack.push(right);
971                    stack.push(left);
972                }
973                Expr::Alias(Alias { expr, .. }) => stack.push(expr),
974                other => return Some(other),
975            }
976        }
977        None
978    })
979}
980
981/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
982///
983/// See [`split_conjunction_owned`] for more details and an example.
984pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
985    let mut stack = vec![expr];
986    std::iter::from_fn(move || {
987        while let Some(expr) = stack.pop() {
988            match expr {
989                Expr::BinaryExpr(BinaryExpr {
990                    right,
991                    op: Operator::And,
992                    left,
993                }) => {
994                    stack.push(*right);
995                    stack.push(*left);
996                }
997                Expr::Alias(Alias { expr, .. }) => stack.push(*expr),
998                other => return Some(other),
999            }
1000        }
1001        None
1002    })
1003}
1004
1005/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1006///
1007/// This is often used to "split" filter expressions such as `col1 = 5
1008/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1009///
1010/// # Example
1011/// ```
1012/// # use datafusion_expr::{col, lit};
1013/// # use datafusion_expr::utils::split_conjunction_owned;
1014/// // a=1 AND b=2
1015/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1016///
1017/// // [a=1, b=2]
1018/// let split = vec![
1019///   col("a").eq(lit(1)),
1020///   col("b").eq(lit(2)),
1021/// ];
1022///
1023/// // use split_conjunction_owned to split them
1024/// assert_eq!(split_conjunction_owned(expr), split);
1025/// ```
1026pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1027    split_binary_owned(expr, Operator::And)
1028}
1029
1030/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1031///
1032/// This is often used to "split" expressions such as `col1 = 5
1033/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1034///
1035/// # Example
1036/// ```
1037/// # use datafusion_expr::{col, lit, Operator};
1038/// # use datafusion_expr::utils::split_binary_owned;
1039/// # use std::ops::Add;
1040/// // a=1 + b=2
1041/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
1042///
1043/// // [a=1, b=2]
1044/// let split = vec![
1045///   col("a").eq(lit(1)),
1046///   col("b").eq(lit(2)),
1047/// ];
1048///
1049/// // use split_binary_owned to split them
1050/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);
1051/// ```
1052pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1053    split_binary_owned_impl(expr, op, vec![])
1054}
1055
1056fn split_binary_owned_impl(
1057    expr: Expr,
1058    operator: Operator,
1059    mut exprs: Vec<Expr>,
1060) -> Vec<Expr> {
1061    match expr {
1062        Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1063            let exprs = split_binary_owned_impl(*left, operator, exprs);
1064            split_binary_owned_impl(*right, operator, exprs)
1065        }
1066        Expr::Alias(Alias { expr, .. }) => {
1067            split_binary_owned_impl(*expr, operator, exprs)
1068        }
1069        other => {
1070            exprs.push(other);
1071            exprs
1072        }
1073    }
1074}
1075
/// Splits a binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1077///
1078/// See [`split_binary_owned`] for more details and an example.
1079pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1080    split_binary_impl(expr, op, vec![])
1081}
1082
1083fn split_binary_impl<'a>(
1084    expr: &'a Expr,
1085    operator: Operator,
1086    mut exprs: Vec<&'a Expr>,
1087) -> Vec<&'a Expr> {
1088    match expr {
1089        Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1090            let exprs = split_binary_impl(left, operator, exprs);
1091            split_binary_impl(right, operator, exprs)
1092        }
1093        Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1094        other => {
1095            exprs.push(other);
1096            exprs
1097        }
1098    }
1099}
1100
1101/// Combines an array of filter expressions into a single filter
1102/// expression consisting of the input filter expressions joined with
1103/// logical AND.
1104///
1105/// Returns None if the filters array is empty.
1106///
1107/// # Example
1108/// ```
1109/// # use datafusion_expr::{col, lit};
1110/// # use datafusion_expr::utils::conjunction;
1111/// // a=1 AND b=2
1112/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1113///
1114/// // [a=1, b=2]
1115/// let split = vec![
1116///   col("a").eq(lit(1)),
1117///   col("b").eq(lit(2)),
1118/// ];
1119///
1120/// // use conjunction to join them together with `AND`
1121/// assert_eq!(conjunction(split), Some(expr));
1122/// ```
1123pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1124    filters.into_iter().reduce(Expr::and)
1125}
1126
1127/// Combines an array of filter expressions into a single filter
1128/// expression consisting of the input filter expressions joined with
1129/// logical OR.
1130///
1131/// Returns None if the filters array is empty.
1132///
1133/// # Example
1134/// ```
1135/// # use datafusion_expr::{col, lit};
1136/// # use datafusion_expr::utils::disjunction;
1137/// // a=1 OR b=2
1138/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
1139///
1140/// // [a=1, b=2]
1141/// let split = vec![
1142///   col("a").eq(lit(1)),
1143///   col("b").eq(lit(2)),
1144/// ];
1145///
1146/// // use disjunction to join them together with `OR`
1147/// assert_eq!(disjunction(split), Some(expr));
1148/// ```
1149pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1150    filters.into_iter().reduce(Expr::or)
1151}
1152
/// Returns a new [LogicalPlan] that filters the output of `plan` with a
1154/// [LogicalPlan::Filter] with all `predicates` ANDed.
1155///
1156/// # Example
1157/// Before:
1158/// ```text
1159/// plan
1160/// ```
1161///
1162/// After:
1163/// ```text
1164/// Filter(predicate)
1165///   plan
1166/// ```
1167pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1168    // reduce filters to a single filter with an AND
1169    let predicate = predicates
1170        .iter()
1171        .skip(1)
1172        .fold(predicates[0].clone(), |acc, predicate| {
1173            and(acc, (*predicate).to_owned())
1174        });
1175
1176    Ok(LogicalPlan::Filter(Filter::try_new(
1177        predicate,
1178        Arc::new(plan),
1179    )?))
1180}
1181
/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and
/// one not in the subquery (closed over from the outer scope)
1184///
1185/// # Arguments
1186///
1187/// * `exprs` - List of expressions that may or may not be joins
1188///
1189/// # Return value
1190///
1191/// Tuple of (expressions containing joins, remaining non-join expressions)
1192pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1193    let mut joins = vec![];
1194    let mut others = vec![];
1195    for filter in exprs.into_iter() {
1196        // If the expression contains correlated predicates, add it to join filters
1197        if filter.contains_outer() {
1198            if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1199            {
1200                joins.push(strip_outer_reference((*filter).clone()));
1201            }
1202        } else {
1203            others.push((*filter).clone());
1204        }
1205    }
1206
1207    Ok((joins, others))
1208}
1209
1210/// Returns the first (and only) element in a slice, or an error
1211///
1212/// # Arguments
1213///
1214/// * `slice` - The slice to extract from
1215///
1216/// # Return value
1217///
1218/// The first element, or an error
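///
/// # Example
///
/// A tiny sketch:
///
/// ```
/// # use datafusion_expr::utils::only_or_err;
/// let one = [42];
/// assert_eq!(*only_or_err(&one).unwrap(), 42);
/// let many = [1, 2];
/// assert!(only_or_err(&many).is_err());
/// ```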
1219pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1220    match slice {
1221        [it] => Ok(it),
1222        [] => plan_err!("No items found!"),
1223        _ => plan_err!("More than one item found!"),
1224    }
1225}
1226
/// Merge the schemas of the inputs into a single schema.
1228///
1229/// This function merges schemas from multiple logical plan inputs using [`DFSchema::merge`].
1230/// Refer to that documentation for details on precedence and metadata handling.
1231pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1232    if inputs.len() == 1 {
1233        inputs[0].schema().as_ref().clone()
1234    } else {
1235        inputs.iter().map(|input| input.schema()).fold(
1236            DFSchema::empty(),
1237            |mut lhs, rhs| {
1238                lhs.merge(rhs);
1239                lhs
1240            },
1241        )
1242    }
1243}
1244
/// Build a state name. The state is the intermediate state of the aggregate function.
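///
/// # Example
///
/// A small sketch of the produced name:
///
/// ```
/// # use datafusion_expr::utils::format_state_name;
/// assert_eq!(format_state_name("first_value", "is_set"), "first_value[is_set]");
/// ```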
1246pub fn format_state_name(name: &str, state_name: &str) -> String {
1247    format!("{name}[{state_name}]")
1248}
1249
1250/// Determine the set of [`Column`]s produced by the subquery.
1251pub fn collect_subquery_cols(
1252    exprs: &[Expr],
1253    subquery_schema: &DFSchema,
1254) -> Result<BTreeSet<Column>> {
1255    exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
1256        let mut using_cols: Vec<Column> = vec![];
1257        for col in expr.column_refs().into_iter() {
1258            if subquery_schema.has_column(col) {
1259                using_cols.push(col.clone());
1260            }
1261        }
1262
1263        cols.extend(using_cols);
1264        Result::<_>::Ok(cols)
1265    })
1266}
1267
1268#[cfg(test)]
1269mod tests {
1270    use super::*;
1271    use crate::{
1272        col, cube,
1273        expr::WindowFunction,
1274        expr_vec_fmt, grouping_set, lit, rollup,
1275        test::function_stub::{max_udaf, min_udaf, sum_udaf},
1276        Cast, ExprFunctionExt, WindowFunctionDefinition,
1277    };
1278    use arrow::datatypes::{UnionFields, UnionMode};
1279
1280    #[test]
1281    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
1282        let result = group_window_expr_by_sort_keys(vec![])?;
1283        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
1284        assert_eq!(expected, result);
1285        Ok(())
1286    }
1287
1288    #[test]
1289    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
1290        let max1 = Expr::from(WindowFunction::new(
1291            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1292            vec![col("name")],
1293        ));
1294        let max2 = Expr::from(WindowFunction::new(
1295            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1296            vec![col("name")],
1297        ));
1298        let min3 = Expr::from(WindowFunction::new(
1299            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1300            vec![col("name")],
1301        ));
1302        let sum4 = Expr::from(WindowFunction::new(
1303            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1304            vec![col("age")],
1305        ));
1306        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1307        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1308        let key = vec![];
1309        let expected: Vec<(WindowSortKey, Vec<Expr>)> =
1310            vec![(key, vec![max1, max2, min3, sum4])];
1311        assert_eq!(expected, result);
1312        Ok(())
1313    }
1314
1315    #[test]
1316    fn test_group_window_expr_by_sort_keys() -> Result<()> {
1317        let age_asc = Sort::new(col("age"), true, true);
1318        let name_desc = Sort::new(col("name"), false, true);
1319        let created_at_desc = Sort::new(col("created_at"), false, true);
1320        let max1 = Expr::from(WindowFunction::new(
1321            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1322            vec![col("name")],
1323        ))
1324        .order_by(vec![age_asc.clone(), name_desc.clone()])
1325        .build()
1326        .unwrap();
1327        let max2 = Expr::from(WindowFunction::new(
1328            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1329            vec![col("name")],
1330        ));
1331        let min3 = Expr::from(WindowFunction::new(
1332            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1333            vec![col("name")],
1334        ))
1335        .order_by(vec![age_asc.clone(), name_desc.clone()])
1336        .build()
1337        .unwrap();
1338        let sum4 = Expr::from(WindowFunction::new(
1339            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1340            vec![col("age")],
1341        ))
1342        .order_by(vec![
1343            name_desc.clone(),
1344            age_asc.clone(),
1345            created_at_desc.clone(),
1346        ])
1347        .build()
1348        .unwrap();
1349        // FIXME use as_ref
1350        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1351        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1352
1353        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
1354        let key2 = vec![];
1355        let key3 = vec![
1356            (name_desc, false),
1357            (age_asc, false),
1358            (created_at_desc, false),
1359        ];
1360
1361        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
1362            (key1, vec![max1, min3]),
1363            (key2, vec![max2]),
1364            (key3, vec![sum4]),
1365        ];
1366        assert_eq!(expected, result);
1367        Ok(())
1368    }
1369
1370    #[test]
1371    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
1372        let asc_or_desc = [true, false];
1373        let nulls_first_or_last = [true, false];
1374        let partition_by = &[col("age"), col("name"), col("created_at")];
1375        for asc_ in asc_or_desc {
1376            for nulls_first_ in nulls_first_or_last {
1377                let order_by = &[
1378                    Sort {
1379                        expr: col("age"),
1380                        asc: asc_,
1381                        nulls_first: nulls_first_,
1382                    },
1383                    Sort {
1384                        expr: col("name"),
1385                        asc: asc_,
1386                        nulls_first: nulls_first_,
1387                    },
1388                ];
1389
1390                let expected = vec![
1391                    (
1392                        Sort {
1393                            expr: col("age"),
1394                            asc: asc_,
1395                            nulls_first: nulls_first_,
1396                        },
1397                        true,
1398                    ),
1399                    (
1400                        Sort {
1401                            expr: col("name"),
1402                            asc: asc_,
1403                            nulls_first: nulls_first_,
1404                        },
1405                        true,
1406                    ),
1407                    (
1408                        Sort {
1409                            expr: col("created_at"),
1410                            asc: true,
1411                            nulls_first: false,
1412                        },
1413                        true,
1414                    ),
1415                ];
1416                let result = generate_sort_key(partition_by, order_by)?;
1417                assert_eq!(expected, result);
1418            }
1419        }
1420        Ok(())
1421    }
1422
1423    #[test]
1424    fn test_enumerate_grouping_sets() -> Result<()> {
1425        let multi_cols = vec![col("col1"), col("col2"), col("col3")];
1426        let simple_col = col("simple_col");
1427        let cube = cube(multi_cols.clone());
1428        let rollup = rollup(multi_cols.clone());
1429        let grouping_set = grouping_set(vec![multi_cols]);
1430
1431        // 1. col
1432        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
1433        let result = format!("[{}]", expr_vec_fmt!(sets));
1434        assert_eq!("[simple_col]", &result);
1435
1436        // 2. cube
1437        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
1438        let result = format!("[{}]", expr_vec_fmt!(sets));
1439        assert_eq!("[CUBE (col1, col2, col3)]", &result);
1440
1441        // 3. rollup
1442        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
1443        let result = format!("[{}]", expr_vec_fmt!(sets));
1444        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
1445
1446        // 4. col + cube
1447        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
1448        let result = format!("[{}]", expr_vec_fmt!(sets));
1449        assert_eq!(
1450            "[GROUPING SETS (\
1451            (simple_col), \
1452            (simple_col, col1), \
1453            (simple_col, col2), \
1454            (simple_col, col1, col2), \
1455            (simple_col, col3), \
1456            (simple_col, col1, col3), \
1457            (simple_col, col2, col3), \
1458            (simple_col, col1, col2, col3))]",
1459            &result
1460        );
1461
1462        // 5. col + rollup
1463        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
1464        let result = format!("[{}]", expr_vec_fmt!(sets));
1465        assert_eq!(
1466            "[GROUPING SETS (\
1467            (simple_col), \
1468            (simple_col, col1), \
1469            (simple_col, col1, col2), \
1470            (simple_col, col1, col2, col3))]",
1471            &result
1472        );
1473
1474        // 6. col + grouping_set
1475        let sets =
1476            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
1477        let result = format!("[{}]", expr_vec_fmt!(sets));
1478        assert_eq!(
1479            "[GROUPING SETS (\
1480            (simple_col, col1, col2, col3))]",
1481            &result
1482        );
1483
1484        // 7. col + grouping_set + rollup
1485        let sets = enumerate_grouping_sets(vec![
1486            simple_col.clone(),
1487            grouping_set,
1488            rollup.clone(),
1489        ])?;
1490        let result = format!("[{}]", expr_vec_fmt!(sets));
1491        assert_eq!(
1492            "[GROUPING SETS (\
1493            (simple_col, col1, col2, col3), \
1494            (simple_col, col1, col2, col3, col1), \
1495            (simple_col, col1, col2, col3, col1, col2), \
1496            (simple_col, col1, col2, col3, col1, col2, col3))]",
1497            &result
1498        );
1499
1500        // 8. col + cube + rollup
1501        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
1502        let result = format!("[{}]", expr_vec_fmt!(sets));
1503        assert_eq!(
1504            "[GROUPING SETS (\
1505            (simple_col), \
1506            (simple_col, col1), \
1507            (simple_col, col1, col2), \
1508            (simple_col, col1, col2, col3), \
1509            (simple_col, col1), \
1510            (simple_col, col1, col1), \
1511            (simple_col, col1, col1, col2), \
1512            (simple_col, col1, col1, col2, col3), \
1513            (simple_col, col2), \
1514            (simple_col, col2, col1), \
1515            (simple_col, col2, col1, col2), \
1516            (simple_col, col2, col1, col2, col3), \
1517            (simple_col, col1, col2), \
1518            (simple_col, col1, col2, col1), \
1519            (simple_col, col1, col2, col1, col2), \
1520            (simple_col, col1, col2, col1, col2, col3), \
1521            (simple_col, col3), \
1522            (simple_col, col3, col1), \
1523            (simple_col, col3, col1, col2), \
1524            (simple_col, col3, col1, col2, col3), \
1525            (simple_col, col1, col3), \
1526            (simple_col, col1, col3, col1), \
1527            (simple_col, col1, col3, col1, col2), \
1528            (simple_col, col1, col3, col1, col2, col3), \
1529            (simple_col, col2, col3), \
1530            (simple_col, col2, col3, col1), \
1531            (simple_col, col2, col3, col1, col2), \
1532            (simple_col, col2, col3, col1, col2, col3), \
1533            (simple_col, col1, col2, col3), \
1534            (simple_col, col1, col2, col3, col1), \
1535            (simple_col, col1, col2, col3, col1, col2), \
1536            (simple_col, col1, col2, col3, col1, col2, col3))]",
1537            &result
1538        );
1539
1540        Ok(())
1541    }

    #[test]
1543    fn test_split_conjunction() {
1544        let expr = col("a");
1545        let result = split_conjunction(&expr);
1546        assert_eq!(result, vec![&expr]);
1547    }
1548
1549    #[test]
1550    fn test_split_conjunction_two() {
1551        let expr = col("a").eq(lit(5)).and(col("b"));
1552        let expr1 = col("a").eq(lit(5));
1553        let expr2 = col("b");
1554
1555        let result = split_conjunction(&expr);
1556        assert_eq!(result, vec![&expr1, &expr2]);
1557    }
1558
1559    #[test]
1560    fn test_split_conjunction_alias() {
1561        let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
1562        let expr1 = col("a").eq(lit(5));
1563        let expr2 = col("b"); // has no alias
1564
1565        let result = split_conjunction(&expr);
1566        assert_eq!(result, vec![&expr1, &expr2]);
1567    }
1568
1569    #[test]
1570    fn test_split_conjunction_or() {
1571        let expr = col("a").eq(lit(5)).or(col("b"));
1572        let result = split_conjunction(&expr);
1573        assert_eq!(result, vec![&expr]);
1574    }
1575
1576    #[test]
1577    fn test_split_binary_owned() {
1578        let expr = col("a");
1579        assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
1580    }
1581
1582    #[test]
1583    fn test_split_binary_owned_two() {
1584        assert_eq!(
1585            split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
1586            vec![col("a").eq(lit(5)), col("b")]
1587        );
1588    }
1589
1590    #[test]
1591    fn test_split_binary_owned_different_op() {
1592        let expr = col("a").eq(lit(5)).or(col("b"));
1593        assert_eq!(
1594            // expr is connected by OR, but pass in AND
1595            split_binary_owned(expr.clone(), Operator::And),
1596            vec![expr]
1597        );
1598    }
1599
1600    #[test]
1601    fn test_split_conjunction_owned() {
1602        let expr = col("a");
1603        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1604    }
1605
1606    #[test]
1607    fn test_split_conjunction_owned_two() {
1608        assert_eq!(
1609            split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
1610            vec![col("a").eq(lit(5)), col("b")]
1611        );
1612    }
1613
1614    #[test]
1615    fn test_split_conjunction_owned_alias() {
1616        assert_eq!(
1617            split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
1618            vec![
1619                col("a").eq(lit(5)),
1620                // no alias on b
1621                col("b"),
1622            ]
1623        );
1624    }
1625
1626    #[test]
1627    fn test_conjunction_empty() {
1628        assert_eq!(conjunction(vec![]), None);
1629    }
1630
1631    #[test]
1632    fn test_conjunction() {
1633        // `[A, B, C]`
1634        let expr = conjunction(vec![col("a"), col("b"), col("c")]);
1635
1636        // --> `(A AND B) AND C`
1637        assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));
1638
1639        // which is different than `A AND (B AND C)`
1640        assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
1641    }
1642
1643    #[test]
1644    fn test_disjunction_empty() {
1645        assert_eq!(disjunction(vec![]), None);
1646    }
1647
1648    #[test]
1649    fn test_disjunction() {
1650        // `[A, B, C]`
1651        let expr = disjunction(vec![col("a"), col("b"), col("c")]);
1652
1653        // --> `(A OR B) OR C`
1654        assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));
1655
1656        // which is different than `A OR (B OR C)`
1657        assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
1658    }
1659
1660    #[test]
1661    fn test_split_conjunction_owned_or() {
1662        let expr = col("a").eq(lit(5)).or(col("b"));
1663        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1664    }
1665
1666    #[test]
1667    fn test_collect_expr() -> Result<()> {
1668        let mut accum: HashSet<Column> = HashSet::new();
1669        expr_to_columns(
1670            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1671            &mut accum,
1672        )?;
1673        expr_to_columns(
1674            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1675            &mut accum,
1676        )?;
1677        assert_eq!(1, accum.len());
1678        assert!(accum.contains(&Column::from_name("a")));
1679        Ok(())
1680    }
1681
1682    #[test]
1683    fn test_can_hash() {
1684        let union_fields: UnionFields = [
1685            (0, Arc::new(Field::new("A", DataType::Int32, true))),
1686            (1, Arc::new(Field::new("B", DataType::Float64, true))),
1687        ]
1688        .into_iter()
1689        .collect();
1690
1691        let union_type = DataType::Union(union_fields, UnionMode::Sparse);
1692        assert!(!can_hash(&union_type));
1693
1694        let list_union_type =
1695            DataType::List(Arc::new(Field::new("my_union", union_type, true)));
1696        assert!(!can_hash(&list_union_type));
1697    }
1698}