Skip to main content

datafusion_expr/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression utilities
19
20use std::cmp::Ordering;
21use std::collections::{BTreeSet, HashSet};
22use std::sync::Arc;
23
24use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
25use crate::expr_rewriter::strip_outer_reference;
26use crate::{
27    BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, and,
28};
29use datafusion_expr_common::signature::{Signature, TypeSignature};
30
31use arrow::datatypes::{DataType, Field, Schema};
32use datafusion_common::tree_node::{
33    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34};
35use datafusion_common::utils::get_at_indices;
36use datafusion_common::{
37    Column, DFSchema, DFSchemaRef, HashMap, Result, TableReference, internal_err,
38    plan_err,
39};
40
41#[cfg(not(feature = "sql"))]
42use crate::sql::{ExceptSelectItem, ExcludeSelectItem, Ident, ObjectName};
43use indexmap::IndexSet;
44#[cfg(feature = "sql")]
45use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, Ident, ObjectName};
46
47pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
48
49///  The value to which `COUNT(*)` is expanded to in
50///  `COUNT(<constant>)` expressions
51pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
52
53/// Count the number of distinct exprs in a list of group by expressions. If the
54/// first element is a `GroupingSet` expression then it must be the only expr.
55pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
56    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
57        if group_expr.len() > 1 {
58            return plan_err!(
59                "Invalid group by expressions, GroupingSet must be the only expression"
60            );
61        }
62        // Groupings sets have an additional integral column for the grouping id
63        Ok(grouping_set.distinct_expr().len() + 1)
64    } else {
65        grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
66    }
67}
68
/// Internal helper that enumerates every subset of `0..len` as a list of indices.
///
/// Each value of a bitmask in `0..2^len` stands for one subset: bit `i` being
/// set means index `i` belongs to the subset. Indices inside a subset are
/// produced in ascending order, and subsets are produced in ascending mask
/// order (so the empty subset comes first).
fn powerset_indices(len: usize) -> impl Iterator<Item = Vec<usize>> {
    let subset_count: u64 = 1 << len;
    (0..subset_count).map(|mask| {
        let mut members = Vec::with_capacity(mask.count_ones() as usize);
        let mut remaining = mask;
        while remaining != 0 {
            // The number of trailing zeros is the position of the lowest set bit
            members.push(remaining.trailing_zeros() as usize);
            // Clear the lowest set bit
            remaining &= remaining - 1;
        }
        members
    })
}
85
86/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
87/// including the empty set and S itself.
88///
89/// Example:
90///
91/// If S is the set {x, y, z}, then all the subsets of S are \
92///  {} \
93///  {x} \
94///  {y} \
95///  {z} \
96///  {x, y} \
97///  {x, z} \
98///  {y, z} \
99///  {x, y, z} \
100///  and hence the power set of S is {{}, {x}, {y}, {z}, {x, y}, {x, z}, {y, z}, {x, y, z}}.
101///
102/// [power set]: https://en.wikipedia.org/wiki/Power_set
103pub fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>> {
104    if slice.len() >= 64 {
105        return plan_err!("The size of the set must be less than 64");
106    }
107
108    Ok(powerset_indices(slice.len())
109        .map(|indices| indices.iter().map(|&idx| &slice[idx]).collect())
110        .collect())
111}
112
113/// check the number of expressions contained in the grouping_set
114fn check_grouping_set_size_limit(size: usize) -> Result<()> {
115    let max_grouping_set_size = 65535;
116    if size > max_grouping_set_size {
117        return plan_err!(
118            "The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"
119        );
120    }
121
122    Ok(())
123}
124
125/// check the number of grouping_set contained in the grouping sets
126fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
127    let max_grouping_sets_size = 4096;
128    if size > max_grouping_sets_size {
129        return plan_err!(
130            "The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"
131        );
132    }
133
134    Ok(())
135}
136
137/// Merge two grouping_set
138///
139/// # Example
140/// ```text
141/// (A, B), (C, D) -> (A, B, C, D)
142/// ```
143///
144/// # Error
145/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
146///
147/// [`DataFusionError`]: datafusion_common::DataFusionError
148fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
149    check_grouping_set_size_limit(left.len() + right.len())?;
150    Ok(left.iter().chain(right.iter()).cloned().collect())
151}
152
153/// Compute the cross product of two grouping_sets
154///
155/// # Example
156/// ```text
157/// [(A, B), (C, D)], [(E), (F)] -> [(A, B, E), (A, B, F), (C, D, E), (C, D, F)]
158/// ```
159///
160/// # Error
161/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
162/// - [`DataFusionError`]: The number of grouping_set in grouping_sets exceeds the maximum limit
163///
164/// [`DataFusionError`]: datafusion_common::DataFusionError
165fn cross_join_grouping_sets<T: Clone>(
166    left: &[Vec<T>],
167    right: &[Vec<T>],
168) -> Result<Vec<Vec<T>>> {
169    let grouping_sets_size = left.len() * right.len();
170
171    check_grouping_sets_size_limit(grouping_sets_size)?;
172
173    let mut result = Vec::with_capacity(grouping_sets_size);
174    for le in left {
175        for re in right {
176            result.push(merge_grouping_set(le, re)?);
177        }
178    }
179    Ok(result)
180}
181
182/// Convert multiple grouping expressions into one [`GroupingSet::GroupingSets`],\
183/// if the grouping expression does not contain [`Expr::GroupingSet`] or only has one expression,\
184/// no conversion will be performed.
185///
186/// e.g.
187///
188/// person.id,\
189/// GROUPING SETS ((person.age, person.salary),(person.age)),\
190/// ROLLUP(person.state, person.birth_date)
191///
192/// =>
193///
194/// GROUPING SETS (\
195///   (person.id, person.age, person.salary),\
196///   (person.id, person.age, person.salary, person.state),\
197///   (person.id, person.age, person.salary, person.state, person.birth_date),\
198///   (person.id, person.age),\
199///   (person.id, person.age, person.state),\
200///   (person.id, person.age, person.state, person.birth_date)\
201/// )
202pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
203    let has_grouping_set = group_expr
204        .iter()
205        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
206    if !has_grouping_set || group_expr.len() == 1 {
207        return Ok(group_expr);
208    }
209    // Only process mix grouping sets
210    let partial_sets = group_expr
211        .iter()
212        .map(|expr| {
213            let exprs = match expr {
214                Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
215                    check_grouping_sets_size_limit(grouping_sets.len())?;
216                    grouping_sets.iter().map(|e| e.iter().collect()).collect()
217                }
218                Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
219                    let grouping_sets = powerset(group_exprs)?;
220                    check_grouping_sets_size_limit(grouping_sets.len())?;
221                    grouping_sets
222                }
223                Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
224                    let size = group_exprs.len();
225                    let slice = group_exprs.as_slice();
226                    check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
227                    (0..(size + 1))
228                        .map(|i| slice[0..i].iter().collect())
229                        .collect()
230                }
231                expr => vec![vec![expr]],
232            };
233            Ok(exprs)
234        })
235        .collect::<Result<Vec<_>>>()?;
236
237    // Cross Join
238    let grouping_sets = partial_sets
239        .into_iter()
240        .map(Ok)
241        .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
242        .transpose()?
243        .map(|e| {
244            e.into_iter()
245                .map(|e| e.into_iter().cloned().collect())
246                .collect()
247        })
248        .unwrap_or_default();
249
250    Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
251        grouping_sets,
252    ))])
253}
254
255/// Find all distinct exprs in a list of group by expressions. If the
256/// first element is a `GroupingSet` expression then it must be the only expr.
257pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
258    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
259        if group_expr.len() > 1 {
260            return plan_err!(
261                "Invalid group by expressions, GroupingSet must be the only expression"
262            );
263        }
264        Ok(grouping_set.distinct_expr())
265    } else {
266        Ok(group_expr
267            .iter()
268            .collect::<IndexSet<_>>()
269            .into_iter()
270            .collect())
271    }
272}
273
274/// Recursively walk an expression tree, collecting the unique set of columns
275/// referenced in the expression
276pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
277    expr.apply(|expr| {
278        match expr {
279            Expr::Column(qc) => {
280                accum.insert(qc.clone());
281            }
282            // Use explicit pattern match instead of a default
283            // implementation, so that in the future if someone adds
284            // new Expr types, they will check here as well
285            // TODO: remove the next line after `Expr::Wildcard` is removed
286            #[expect(deprecated)]
287            Expr::Unnest(_)
288            | Expr::ScalarVariable(_, _)
289            | Expr::Alias(_)
290            | Expr::Literal(_, _)
291            | Expr::BinaryExpr { .. }
292            | Expr::Like { .. }
293            | Expr::SimilarTo { .. }
294            | Expr::Not(_)
295            | Expr::IsNotNull(_)
296            | Expr::IsNull(_)
297            | Expr::IsTrue(_)
298            | Expr::IsFalse(_)
299            | Expr::IsUnknown(_)
300            | Expr::IsNotTrue(_)
301            | Expr::IsNotFalse(_)
302            | Expr::IsNotUnknown(_)
303            | Expr::Negative(_)
304            | Expr::Between { .. }
305            | Expr::Case { .. }
306            | Expr::Cast { .. }
307            | Expr::TryCast { .. }
308            | Expr::ScalarFunction(..)
309            | Expr::WindowFunction { .. }
310            | Expr::AggregateFunction { .. }
311            | Expr::GroupingSet(_)
312            | Expr::InList { .. }
313            | Expr::Exists { .. }
314            | Expr::InSubquery(_)
315            | Expr::SetComparison(_)
316            | Expr::ScalarSubquery(_)
317            | Expr::Wildcard { .. }
318            | Expr::Placeholder(_)
319            | Expr::OuterReferenceColumn { .. }
320            | Expr::HigherOrderFunction(_)
321            | Expr::Lambda(_)
322            | Expr::LambdaVariable(_) => {}
323        }
324        Ok(TreeNodeRecursion::Continue)
325    })
326    .map(|_| ())
327}
328
329/// Find excluded columns in the schema, if any
330/// SELECT * EXCLUDE(col1, col2), would return `vec![col1, col2]`
331fn get_excluded_columns(
332    opt_exclude: Option<&ExcludeSelectItem>,
333    opt_except: Option<&ExceptSelectItem>,
334    schema: &DFSchema,
335    qualifier: Option<&TableReference>,
336) -> Result<Vec<Column>> {
337    let mut idents = vec![];
338    if let Some(excepts) = opt_except {
339        idents.push(&excepts.first_element);
340        idents.extend(&excepts.additional_elements);
341    }
342    // Declared outside the `if let` so `idents.extend(exclude_owned.iter())`
343    // below can borrow references that outlive the inner scope.
344    let exclude_owned: Vec<Ident>;
345    if let Some(exclude) = opt_exclude {
346        let object_name_to_ident = |name: &ObjectName| -> Result<Ident> {
347            if name.0.len() != 1 {
348                return plan_err!(
349                    "EXCLUDE with multi-part identifiers is not supported: {name}"
350                );
351            }
352            let part = &name.0[0];
353            let Some(ident) = part.as_ident() else {
354                return plan_err!(
355                    "EXCLUDE with non-identifier name part is not supported: {part}"
356                );
357            };
358            Ok(ident.clone())
359        };
360        exclude_owned = match exclude {
361            ExcludeSelectItem::Single(name) => vec![object_name_to_ident(name)?],
362            ExcludeSelectItem::Multiple(names) => names
363                .iter()
364                .map(object_name_to_ident)
365                .collect::<Result<Vec<_>>>()?,
366        };
367        idents.extend(exclude_owned.iter());
368    }
369    // Excluded columns should be unique
370    let n_elem = idents.len();
371    let unique_idents = idents.into_iter().collect::<HashSet<_>>();
372    // If HashSet size, and vector length are different, this means that some of the excluded columns
373    // are not unique. In this case return error.
374    if n_elem != unique_idents.len() {
375        return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
376    }
377
378    let mut result = vec![];
379    for ident in unique_idents.into_iter() {
380        let col_name = ident.value.as_str();
381        let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
382        result.push(Column::from((qualifier, field)));
383    }
384    Ok(result)
385}
386
387/// Returns all `Expr`s in the schema, except the `Column`s in the `columns_to_skip`
388fn get_exprs_except_skipped(
389    schema: &DFSchema,
390    columns_to_skip: &HashSet<Column>,
391) -> Vec<Expr> {
392    if columns_to_skip.is_empty() {
393        schema.iter().map(Expr::from).collect::<Vec<Expr>>()
394    } else {
395        schema
396            .columns()
397            .iter()
398            .filter_map(|c| {
399                if !columns_to_skip.contains(c) {
400                    Some(Expr::Column(c.clone()))
401                } else {
402                    None
403                }
404            })
405            .collect::<Vec<Expr>>()
406    }
407}
408
409/// When a JOIN has a USING clause, the join columns appear in the output
410/// schema once per side (for inner/outer joins) or once total (for semi/anti
411/// joins). An unqualified wildcard should include each USING column only once.
412/// This function returns the duplicate columns that should be excluded.
413fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
414    let output_columns: HashSet<_> = plan.schema().columns().iter().cloned().collect();
415    let mut excluded = HashSet::new();
416    for cols in plan.using_columns()? {
417        // `using_columns()` returns join columns from both sides regardless of
418        // the join type. For semi/anti joins, only one side's columns appear in
419        // the output schema. Filter to output columns so that columns from the
420        // non-output side don't participate in the deduplication process below
421        // and displace real output columns.
422        let mut cols: Vec<_> = cols
423            .into_iter()
424            .filter(|c| output_columns.contains(c))
425            .collect();
426
427        // Sort so we keep the same qualified column, regardless of HashSet
428        // iteration order.
429        cols.sort();
430
431        // Keep only one column per name from the columns set, adding any
432        // duplicates to the excluded set.
433        let mut seen_names = HashSet::new();
434        for col in cols {
435            if seen_names.contains(col.name.as_str()) {
436                excluded.insert(col); // exclude columns with already seen name
437            } else {
438                seen_names.insert(col.name.clone()); // mark column name as seen
439            }
440        }
441    }
442    Ok(excluded)
443}
444
445/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s.
446pub fn expand_wildcard(
447    schema: &DFSchema,
448    plan: &LogicalPlan,
449    wildcard_options: Option<&WildcardOptions>,
450) -> Result<Vec<Expr>> {
451    let mut columns_to_skip = exclude_using_columns(plan)?;
452    let excluded_columns = if let Some(WildcardOptions {
453        exclude: opt_exclude,
454        except: opt_except,
455        ..
456    }) = wildcard_options
457    {
458        get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
459    } else {
460        vec![]
461    };
462    // Add each excluded `Column` to columns_to_skip
463    columns_to_skip.extend(excluded_columns);
464    Ok(get_exprs_except_skipped(schema, &columns_to_skip))
465}
466
467/// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s.
468pub fn expand_qualified_wildcard(
469    qualifier: &TableReference,
470    schema: &DFSchema,
471    wildcard_options: Option<&WildcardOptions>,
472) -> Result<Vec<Expr>> {
473    let qualified_indices = schema.fields_indices_with_qualified(qualifier);
474    let projected_func_dependencies = schema
475        .functional_dependencies()
476        .project_functional_dependencies(&qualified_indices, qualified_indices.len());
477    let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
478    if fields_with_qualified.is_empty() {
479        return plan_err!("Invalid qualifier {qualifier}");
480    }
481
482    let qualified_schema = Arc::new(Schema::new_with_metadata(
483        fields_with_qualified,
484        schema.metadata().clone(),
485    ));
486    let qualified_dfschema =
487        DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
488            .with_functional_dependencies(projected_func_dependencies)?;
489    let excluded_columns = if let Some(WildcardOptions {
490        exclude: opt_exclude,
491        except: opt_except,
492        ..
493    }) = wildcard_options
494    {
495        get_excluded_columns(
496            opt_exclude.as_ref(),
497            opt_except.as_ref(),
498            schema,
499            Some(qualifier),
500        )?
501    } else {
502        vec![]
503    };
504    // Add each excluded `Column` to columns_to_skip
505    let mut columns_to_skip = HashSet::new();
506    columns_to_skip.extend(excluded_columns);
507    Ok(get_exprs_except_skipped(
508        &qualified_dfschema,
509        &columns_to_skip,
510    ))
511}
512
513/// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)")
514/// If bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column
515type WindowSortKey = Vec<(Sort, bool)>;
516
517/// Generate a sort key for a given window expr's partition_by and order_by expr
518pub fn generate_sort_key(
519    partition_by: &[Expr],
520    order_by: &[Sort],
521) -> Result<WindowSortKey> {
522    let normalized_order_by_keys = order_by
523        .iter()
524        .map(|e| {
525            let Sort { expr, .. } = e;
526            Sort::new(expr.clone(), true, false)
527        })
528        .collect::<Vec<_>>();
529
530    let mut final_sort_keys = vec![];
531    let mut is_partition_flag = vec![];
532    partition_by.iter().for_each(|e| {
533        // By default, create sort key with ASC is true and NULLS LAST to be consistent with
534        // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html
535        let e = e.clone().sort(true, false);
536        if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
537            let order_by_key = &order_by[pos];
538            if !final_sort_keys.contains(order_by_key) {
539                final_sort_keys.push(order_by_key.clone());
540                is_partition_flag.push(true);
541            }
542        } else if !final_sort_keys.contains(&e) {
543            final_sort_keys.push(e);
544            is_partition_flag.push(true);
545        }
546    });
547
548    order_by.iter().for_each(|e| {
549        if !final_sort_keys.contains(e) {
550            final_sort_keys.push(e.clone());
551            is_partition_flag.push(false);
552        }
553    });
554    let res = final_sort_keys
555        .into_iter()
556        .zip(is_partition_flag)
557        .collect::<Vec<_>>();
558    Ok(res)
559}
560
561/// Compare the sort expr as PostgreSQL's common_prefix_cmp():
562/// <https://github.com/postgres/postgres/blob/master/src/backend/optimizer/plan/planner.c>
563pub fn compare_sort_expr(
564    sort_expr_a: &Sort,
565    sort_expr_b: &Sort,
566    schema: &DFSchemaRef,
567) -> Ordering {
568    let Sort {
569        expr: expr_a,
570        asc: asc_a,
571        nulls_first: nulls_first_a,
572    } = sort_expr_a;
573
574    let Sort {
575        expr: expr_b,
576        asc: asc_b,
577        nulls_first: nulls_first_b,
578    } = sort_expr_b;
579
580    let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
581    let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
582    for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
583        match idx_a.cmp(idx_b) {
584            Ordering::Less => {
585                return Ordering::Less;
586            }
587            Ordering::Greater => {
588                return Ordering::Greater;
589            }
590            Ordering::Equal => {}
591        }
592    }
593    match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
594        Ordering::Less => return Ordering::Greater,
595        Ordering::Greater => {
596            return Ordering::Less;
597        }
598        Ordering::Equal => {}
599    }
600    match (asc_a, asc_b) {
601        (true, false) => {
602            return Ordering::Greater;
603        }
604        (false, true) => {
605            return Ordering::Less;
606        }
607        _ => {}
608    }
609    match (nulls_first_a, nulls_first_b) {
610        (true, false) => {
611            return Ordering::Less;
612        }
613        (false, true) => {
614            return Ordering::Greater;
615        }
616        _ => {}
617    }
618    Ordering::Equal
619}
620
621/// Group a slice of window expression expr by their order by expressions
622pub fn group_window_expr_by_sort_keys(
623    window_expr: impl IntoIterator<Item = Expr>,
624) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
625    let mut result = vec![];
626    window_expr.into_iter().try_for_each(|expr| match &expr {
627        Expr::WindowFunction(window_fun) => {
628            let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params;
629            let sort_key = generate_sort_key(partition_by, order_by)?;
630            if let Some((_, values)) = result.iter_mut().find(
631                |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
632            ) {
633                values.push(expr);
634            } else {
635                result.push((sort_key, vec![expr]))
636            }
637            Ok(())
638        }
639        other => internal_err!(
640            "Impossibly got non-window expr {other:?}"
641        ),
642    })?;
643    Ok(result)
644}
645
646/// Collect all deeply nested `Expr::AggregateFunction`.
647/// They are returned in order of occurrence (depth
648/// first), with duplicates omitted.
649pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
650    find_exprs_in_exprs(exprs, &|nested_expr| {
651        matches!(nested_expr, Expr::AggregateFunction { .. })
652    })
653}
654
655/// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence
656/// (depth first), with duplicates omitted.
657pub fn find_window_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
658    find_exprs_in_exprs(exprs, &|nested_expr| {
659        matches!(nested_expr, Expr::WindowFunction { .. })
660    })
661}
662
663/// Collect all deeply nested `Expr::OuterReferenceColumn`. They are returned in order of occurrence
664/// (depth first), with duplicates omitted.
665pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
666    find_exprs_in_expr(expr, &|nested_expr| {
667        matches!(nested_expr, Expr::OuterReferenceColumn { .. })
668    })
669}
670
671/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
672/// pass the provided test. The returned `Expr`'s are deduplicated and returned
673/// in order of appearance (depth first).
674fn find_exprs_in_exprs<'a, F>(
675    exprs: impl IntoIterator<Item = &'a Expr>,
676    test_fn: &F,
677) -> Vec<Expr>
678where
679    F: Fn(&Expr) -> bool,
680{
681    exprs
682        .into_iter()
683        .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
684        .fold(vec![], |mut acc, expr| {
685            if !acc.contains(&expr) {
686                acc.push(expr)
687            }
688            acc
689        })
690}
691
692/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
693/// provided test. The returned `Expr`'s are deduplicated and returned in order
694/// of appearance (depth first).
695fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
696where
697    F: Fn(&Expr) -> bool,
698{
699    let mut exprs = vec![];
700    expr.apply(|expr| {
701        if test_fn(expr) {
702            if !(exprs.contains(expr)) {
703                exprs.push(expr.clone())
704            }
705            // Stop recursing down this expr once we find a match
706            return Ok(TreeNodeRecursion::Jump);
707        }
708
709        Ok(TreeNodeRecursion::Continue)
710    })
711    // pre_visit always returns OK, so this will always too
712    .expect("no way to return error during recursion");
713    exprs
714}
715
716/// Recursively inspect an [`Expr`] and all its children.
717pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
718where
719    F: FnMut(&Expr) -> Result<(), E>,
720{
721    let mut err = Ok(());
722    expr.apply(|expr| {
723        if let Err(e) = f(expr) {
724            // Save the error for later (it may not be a DataFusionError)
725            err = Err(e);
726            Ok(TreeNodeRecursion::Stop)
727        } else {
728            // keep going
729            Ok(TreeNodeRecursion::Continue)
730        }
731    })
732    // The closure always returns OK, so this will always too
733    .expect("no way to return error during recursion");
734
735    err
736}
737
738/// Create schema fields from an expression list, for use in result set schema construction
739///
740/// This function converts a list of expressions into a list of complete schema fields,
741/// making comprehensive determinations about each field's properties including:
742/// - **Data type**: Resolved based on expression type and input schema context
743/// - **Nullability**: Determined by expression-specific nullability rules
744/// - **Metadata**: Computed based on expression type (preserving, merging, or generating new metadata)
745/// - **Table reference scoping**: Establishing proper qualified field references
746///
747/// Each expression is converted to a field by calling [`Expr::to_field`], which performs
748/// the complete field resolution process for all field properties.
749///
750/// # Returns
751///
752/// A `Result` containing a vector of `(Option<TableReference>, Arc<Field>)` tuples,
753/// where each Field contains complete schema information (type, nullability, metadata)
754/// and proper table reference scoping for the corresponding expression.
755pub fn exprlist_to_fields<'a>(
756    exprs: impl IntoIterator<Item = &'a Expr>,
757    plan: &LogicalPlan,
758) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
759    // Look for exact match in plan's output schema
760    let input_schema = plan.schema();
761    exprs
762        .into_iter()
763        .map(|e| e.to_field(input_schema))
764        .collect()
765}
766
767/// Convert an expression into Column expression if it's already provided as input plan.
768///
769/// For example, it rewrites:
770///
771/// ```text
772/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
773/// .project(vec![col("c1"), sum(col("c2"))?
774/// ```
775///
776/// Into:
777///
778/// ```text
779/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
780/// .project(vec![col("c1"), col("SUM(c2)")?
781/// ```
782pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
783    let output_exprs = match input.columnized_output_exprs() {
784        Ok(exprs) if !exprs.is_empty() => exprs,
785        _ => return Ok(e),
786    };
787    let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
788    e.transform_down(|node: Expr| match exprs_map.get(&node) {
789        Some(column) => Ok(Transformed::new(
790            Expr::Column(column.clone()),
791            true,
792            TreeNodeRecursion::Jump,
793        )),
794        None => Ok(Transformed::no(node)),
795    })
796    .data()
797}
798
799/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
800/// appearance (depth first), and may contain duplicates.
801pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
802    exprs
803        .iter()
804        .flat_map(find_columns_referenced_by_expr)
805        .map(Expr::Column)
806        .collect()
807}
808
809pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
810    let mut exprs = vec![];
811    e.apply(|expr| {
812        if let Expr::Column(c) = expr {
813            exprs.push(c.clone())
814        }
815        Ok(TreeNodeRecursion::Continue)
816    })
817    // As the closure always returns Ok, this "can't" error
818    .expect("Unexpected error");
819    exprs
820}
821
822/// Convert any `Expr` to an `Expr::Column`.
823pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
824    match expr {
825        Expr::Column(col) => {
826            let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
827            Ok(Expr::from(Column::from((qualifier, field))))
828        }
829        _ => Ok(Expr::Column(Column::from_name(
830            expr.schema_name().to_string(),
831        ))),
832    }
833}
834
835/// Recursively walk an expression tree, collecting the column indexes
836/// referenced in the expression
837pub(crate) fn find_column_indexes_referenced_by_expr(
838    e: &Expr,
839    schema: &DFSchemaRef,
840) -> Vec<usize> {
841    let mut indexes = vec![];
842    e.apply(|expr| {
843        match expr {
844            Expr::Column(qc) => {
845                if let Ok(idx) = schema.index_of_column(qc) {
846                    indexes.push(idx);
847                }
848            }
849            Expr::Literal(_, _) => {
850                indexes.push(usize::MAX);
851            }
852            _ => {}
853        }
854        Ok(TreeNodeRecursion::Continue)
855    })
856    .unwrap();
857    indexes
858}
859
860/// Can this data type be used in hash join equal conditions??
861/// Data types here come from function 'equal_rows', if more data types are supported
862/// in create_hashes, add those data types here to generate join logical plan.
863pub fn can_hash(data_type: &DataType) -> bool {
864    match data_type {
865        DataType::Null => true,
866        DataType::Boolean => true,
867        DataType::Int8 => true,
868        DataType::Int16 => true,
869        DataType::Int32 => true,
870        DataType::Int64 => true,
871        DataType::UInt8 => true,
872        DataType::UInt16 => true,
873        DataType::UInt32 => true,
874        DataType::UInt64 => true,
875        DataType::Float16 => true,
876        DataType::Float32 => true,
877        DataType::Float64 => true,
878        DataType::Decimal32(_, _) => true,
879        DataType::Decimal64(_, _) => true,
880        DataType::Decimal128(_, _) => true,
881        DataType::Decimal256(_, _) => true,
882        DataType::Timestamp(_, _) => true,
883        DataType::Utf8 => true,
884        DataType::LargeUtf8 => true,
885        DataType::Utf8View => true,
886        DataType::Binary => true,
887        DataType::LargeBinary => true,
888        DataType::BinaryView => true,
889        DataType::Date32 => true,
890        DataType::Date64 => true,
891        DataType::Time32(_) => true,
892        DataType::Time64(_) => true,
893        DataType::Duration(_) => true,
894        DataType::Interval(_) => true,
895        DataType::FixedSizeBinary(_) => true,
896        DataType::Dictionary(key_type, value_type) => {
897            DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
898        }
899        DataType::List(value_type) => can_hash(value_type.data_type()),
900        DataType::LargeList(value_type) => can_hash(value_type.data_type()),
901        DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
902        DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
903        DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
904
905        DataType::ListView(_)
906        | DataType::LargeListView(_)
907        | DataType::Union(_, _)
908        | DataType::RunEndEncoded(_, _) => false,
909    }
910}
911
912/// Check whether all columns are from the schema.
913pub fn check_all_columns_from_schema(
914    columns: &HashSet<&Column>,
915    schema: &DFSchema,
916) -> Result<bool> {
917    for col in columns.iter() {
918        let exist = schema.is_column_from_schema(col);
919        if !exist {
920            return Ok(false);
921        }
922    }
923
924    Ok(true)
925}
926
927/// Give two sides of the equijoin predicate, return a valid join key pair.
928/// If there is no valid join key pair, return None.
929///
930/// A valid join means:
931/// 1. All referenced column of the left side is from the left schema, and
932///    all referenced column of the right side is from the right schema.
933/// 2. Or opposite. All referenced column of the left side is from the right schema,
934///    and the right side is from the left schema.
935pub fn find_valid_equijoin_key_pair(
936    left_key: &Expr,
937    right_key: &Expr,
938    left_schema: &DFSchema,
939    right_schema: &DFSchema,
940) -> Result<Option<(Expr, Expr)>> {
941    let left_using_columns = left_key.column_refs();
942    let right_using_columns = right_key.column_refs();
943
944    // Conditions like a = 10, will be added to non-equijoin.
945    if left_using_columns.is_empty() || right_using_columns.is_empty() {
946        return Ok(None);
947    }
948
949    if check_all_columns_from_schema(&left_using_columns, left_schema)?
950        && check_all_columns_from_schema(&right_using_columns, right_schema)?
951    {
952        return Ok(Some((left_key.clone(), right_key.clone())));
953    } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
954        && check_all_columns_from_schema(&left_using_columns, right_schema)?
955    {
956        return Ok(Some((right_key.clone(), left_key.clone())));
957    }
958
959    Ok(None)
960}
961
962/// Creates a detailed error message for a function with wrong signature.
963///
964/// For example, a query like `select round(3.14, 1.1);` would yield:
965/// ```text
966/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts.
967///     Candidate functions:
968///     round(Float64, Int64)
969///     round(Float32, Int64)
970///     round(Float64)
971///     round(Float32)
972/// ```
973#[expect(clippy::needless_pass_by_value)]
974#[deprecated(since = "53.0.0", note = "Internal function")]
975pub fn generate_signature_error_msg(
976    func_name: &str,
977    func_signature: Signature,
978    input_expr_types: &[DataType],
979) -> String {
980    let candidate_signatures = func_signature
981        .type_signature
982        .to_string_repr_with_names(func_signature.parameter_names.as_deref())
983        .iter()
984        .map(|args_str| format!("\t{func_name}({args_str})"))
985        .collect::<Vec<String>>()
986        .join("\n");
987
988    format!(
989        "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
990        func_name,
991        TypeSignature::join_types(input_expr_types, ", "),
992        candidate_signatures
993    )
994}
995
996/// Creates a detailed error message for a function with wrong signature.
997///
998/// For example, a query like `select round(3.14, 1.1);` would yield:
999/// ```text
1000/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts.
1001///     Candidate functions:
1002///     round(Float64, Int64)
1003///     round(Float32, Int64)
1004///     round(Float64)
1005///     round(Float32)
1006/// ```
1007pub(crate) fn generate_signature_error_message(
1008    func_name: &str,
1009    func_signature: &Signature,
1010    input_expr_types: &[DataType],
1011) -> String {
1012    #[expect(deprecated)]
1013    generate_signature_error_msg(func_name, func_signature.clone(), input_expr_types)
1014}
1015
1016/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1017///
1018/// See [`split_conjunction_owned`] for more details and an example.
1019pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
1020    split_conjunction_impl(expr, vec![])
1021}
1022
1023fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
1024    match expr {
1025        Expr::BinaryExpr(BinaryExpr {
1026            right,
1027            op: Operator::And,
1028            left,
1029        }) => {
1030            let exprs = split_conjunction_impl(left, exprs);
1031            split_conjunction_impl(right, exprs)
1032        }
1033        Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
1034        other => {
1035            exprs.push(other);
1036            exprs
1037        }
1038    }
1039}
1040
1041/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1042///
1043/// See [`split_conjunction_owned`] for more details and an example.
1044pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
1045    let mut stack = vec![expr];
1046    std::iter::from_fn(move || {
1047        while let Some(expr) = stack.pop() {
1048            match expr {
1049                Expr::BinaryExpr(BinaryExpr {
1050                    right,
1051                    op: Operator::And,
1052                    left,
1053                }) => {
1054                    stack.push(right);
1055                    stack.push(left);
1056                }
1057                Expr::Alias(Alias { expr, .. }) => stack.push(expr),
1058                other => return Some(other),
1059            }
1060        }
1061        None
1062    })
1063}
1064
1065/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1066///
1067/// See [`split_conjunction_owned`] for more details and an example.
1068pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
1069    let mut stack = vec![expr];
1070    std::iter::from_fn(move || {
1071        while let Some(expr) = stack.pop() {
1072            match expr {
1073                Expr::BinaryExpr(BinaryExpr {
1074                    right,
1075                    op: Operator::And,
1076                    left,
1077                }) => {
1078                    stack.push(*right);
1079                    stack.push(*left);
1080                }
1081                Expr::Alias(Alias { expr, .. }) => stack.push(*expr),
1082                other => return Some(other),
1083            }
1084        }
1085        None
1086    })
1087}
1088
1089/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1090///
1091/// This is often used to "split" filter expressions such as `col1 = 5
1092/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1093///
1094/// # Example
1095/// ```
1096/// # use datafusion_expr::{col, lit};
1097/// # use datafusion_expr::utils::split_conjunction_owned;
1098/// // a=1 AND b=2
1099/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1100///
1101/// // [a=1, b=2]
1102/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1103///
1104/// // use split_conjunction_owned to split them
1105/// assert_eq!(split_conjunction_owned(expr), split);
1106/// ```
1107pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1108    split_binary_owned(expr, Operator::And)
1109}
1110
1111/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1112///
1113/// This is often used to "split" expressions such as `col1 = 5
1114/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1115///
1116/// # Example
1117/// ```
1118/// # use datafusion_expr::{col, lit, Operator};
1119/// # use datafusion_expr::utils::split_binary_owned;
1120/// # use std::ops::Add;
1121/// // a=1 + b=2
1122/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
1123///
1124/// // [a=1, b=2]
1125/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1126///
1127/// // use split_binary_owned to split them
1128/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);
1129/// ```
1130pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1131    split_binary_owned_impl(expr, op, vec![])
1132}
1133
1134fn split_binary_owned_impl(
1135    expr: Expr,
1136    operator: Operator,
1137    mut exprs: Vec<Expr>,
1138) -> Vec<Expr> {
1139    match expr {
1140        Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1141            let exprs = split_binary_owned_impl(*left, operator, exprs);
1142            split_binary_owned_impl(*right, operator, exprs)
1143        }
1144        Expr::Alias(Alias { expr, .. }) => {
1145            split_binary_owned_impl(*expr, operator, exprs)
1146        }
1147        other => {
1148            exprs.push(other);
1149            exprs
1150        }
1151    }
1152}
1153
1154/// Splits an binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1155///
1156/// See [`split_binary_owned`] for more details and an example.
1157pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1158    split_binary_impl(expr, op, vec![])
1159}
1160
1161fn split_binary_impl<'a>(
1162    expr: &'a Expr,
1163    operator: Operator,
1164    mut exprs: Vec<&'a Expr>,
1165) -> Vec<&'a Expr> {
1166    match expr {
1167        Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1168            let exprs = split_binary_impl(left, operator, exprs);
1169            split_binary_impl(right, operator, exprs)
1170        }
1171        Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1172        other => {
1173            exprs.push(other);
1174            exprs
1175        }
1176    }
1177}
1178
1179/// Combines an array of filter expressions into a single filter
1180/// expression consisting of the input filter expressions joined with
1181/// logical AND.
1182///
1183/// Returns None if the filters array is empty.
1184///
1185/// # Example
1186/// ```
1187/// # use datafusion_expr::{col, lit};
1188/// # use datafusion_expr::utils::conjunction;
1189/// // a=1 AND b=2
1190/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1191///
1192/// // [a=1, b=2]
1193/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1194///
1195/// // use conjunction to join them together with `AND`
1196/// assert_eq!(conjunction(split), Some(expr));
1197/// ```
1198pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1199    filters.into_iter().reduce(Expr::and)
1200}
1201
1202/// Combines an array of filter expressions into a single filter
1203/// expression consisting of the input filter expressions joined with
1204/// logical OR.
1205///
1206/// Returns None if the filters array is empty.
1207///
1208/// # Example
1209/// ```
1210/// # use datafusion_expr::{col, lit};
1211/// # use datafusion_expr::utils::disjunction;
1212/// // a=1 OR b=2
1213/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
1214///
1215/// // [a=1, b=2]
1216/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1217///
1218/// // use disjunction to join them together with `OR`
1219/// assert_eq!(disjunction(split), Some(expr));
1220/// ```
1221pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1222    filters.into_iter().reduce(Expr::or)
1223}
1224
1225/// Returns a new [LogicalPlan] that filters the output of  `plan` with a
1226/// [LogicalPlan::Filter] with all `predicates` ANDed.
1227///
1228/// # Example
1229/// Before:
1230/// ```text
1231/// plan
1232/// ```
1233///
1234/// After:
1235/// ```text
1236/// Filter(predicate)
1237///   plan
1238/// ```
1239pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1240    // reduce filters to a single filter with an AND
1241    let predicate = predicates
1242        .iter()
1243        .skip(1)
1244        .fold(predicates[0].clone(), |acc, predicate| {
1245            and(acc, (*predicate).to_owned())
1246        });
1247
1248    Ok(LogicalPlan::Filter(Filter::try_new(
1249        predicate,
1250        Arc::new(plan),
1251    )?))
1252}
1253
1254/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and
1255/// one not in the subquery (closed upon from outer scope)
1256///
1257/// # Arguments
1258///
1259/// * `exprs` - List of expressions that may or may not be joins
1260///
1261/// # Return value
1262///
1263/// Tuple of (expressions containing joins, remaining non-join expressions)
1264pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1265    let mut joins = vec![];
1266    let mut others = vec![];
1267    for filter in exprs.into_iter() {
1268        // If the expression contains correlated predicates, add it to join filters
1269        if filter.contains_outer() {
1270            if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1271            {
1272                joins.push(strip_outer_reference((*filter).clone()));
1273            }
1274        } else {
1275            others.push((*filter).clone());
1276        }
1277    }
1278
1279    Ok((joins, others))
1280}
1281
1282/// Returns the first (and only) element in a slice, or an error
1283///
1284/// # Arguments
1285///
1286/// * `slice` - The slice to extract from
1287///
1288/// # Return value
1289///
1290/// The first element, or an error
1291pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1292    match slice {
1293        [it] => Ok(it),
1294        [] => plan_err!("No items found!"),
1295        _ => plan_err!("More than one item found!"),
1296    }
1297}
1298
1299/// merge inputs schema into a single schema.
1300///
1301/// This function merges schemas from multiple logical plan inputs using [`DFSchema::merge`].
1302/// Refer to that documentation for details on precedence and metadata handling.
1303pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1304    if inputs.len() == 1 {
1305        inputs[0].schema().as_ref().clone()
1306    } else {
1307        inputs.iter().map(|input| input.schema()).fold(
1308            DFSchema::empty(),
1309            |mut lhs, rhs| {
1310                lhs.merge(rhs);
1311                lhs
1312            },
1313        )
1314    }
1315}
1316
/// Builds a state name of the form `name[state_name]`. State is the
/// intermediate state of the aggregate function.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    [name, "[", state_name, "]"].concat()
}
1321
1322/// Determine the set of [`Column`]s produced by the subquery.
1323pub fn collect_subquery_cols(
1324    exprs: &[Expr],
1325    subquery_schema: &DFSchema,
1326) -> Result<BTreeSet<Column>> {
1327    exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
1328        let mut using_cols: Vec<Column> = vec![];
1329        for col in expr.column_refs().into_iter() {
1330            if subquery_schema.has_column(col) {
1331                using_cols.push(col.clone());
1332            }
1333        }
1334
1335        cols.extend(using_cols);
1336        Result::<_>::Ok(cols)
1337    })
1338}
1339
1340#[cfg(test)]
1341mod tests {
1342    use super::*;
1343    use crate::{
1344        Cast, ExprFunctionExt, WindowFunctionDefinition, col, cube,
1345        expr::WindowFunction,
1346        expr_vec_fmt, grouping_set, lit, rollup,
1347        test::function_stub::{max_udaf, min_udaf, sum_udaf},
1348    };
1349    use arrow::datatypes::{UnionFields, UnionMode};
1350    use datafusion_expr_common::signature::Volatility;
1351
1352    #[test]
1353    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
1354        let result = group_window_expr_by_sort_keys(vec![])?;
1355        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
1356        assert_eq!(expected, result);
1357        Ok(())
1358    }
1359
1360    #[test]
1361    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
1362        let max1 = Expr::from(WindowFunction::new(
1363            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1364            vec![col("name")],
1365        ));
1366        let max2 = Expr::from(WindowFunction::new(
1367            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1368            vec![col("name")],
1369        ));
1370        let min3 = Expr::from(WindowFunction::new(
1371            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1372            vec![col("name")],
1373        ));
1374        let sum4 = Expr::from(WindowFunction::new(
1375            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1376            vec![col("age")],
1377        ));
1378        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1379        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1380        let key = vec![];
1381        let expected: Vec<(WindowSortKey, Vec<Expr>)> =
1382            vec![(key, vec![max1, max2, min3, sum4])];
1383        assert_eq!(expected, result);
1384        Ok(())
1385    }
1386
1387    #[test]
1388    fn test_group_window_expr_by_sort_keys() -> Result<()> {
1389        let age_asc = Sort::new(col("age"), true, true);
1390        let name_desc = Sort::new(col("name"), false, true);
1391        let created_at_desc = Sort::new(col("created_at"), false, true);
1392        let max1 = Expr::from(WindowFunction::new(
1393            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1394            vec![col("name")],
1395        ))
1396        .order_by(vec![age_asc.clone(), name_desc.clone()])
1397        .build()
1398        .unwrap();
1399        let max2 = Expr::from(WindowFunction::new(
1400            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1401            vec![col("name")],
1402        ));
1403        let min3 = Expr::from(WindowFunction::new(
1404            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1405            vec![col("name")],
1406        ))
1407        .order_by(vec![age_asc.clone(), name_desc.clone()])
1408        .build()
1409        .unwrap();
1410        let sum4 = Expr::from(WindowFunction::new(
1411            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1412            vec![col("age")],
1413        ))
1414        .order_by(vec![
1415            name_desc.clone(),
1416            age_asc.clone(),
1417            created_at_desc.clone(),
1418        ])
1419        .build()
1420        .unwrap();
1421        // FIXME use as_ref
1422        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1423        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1424
1425        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
1426        let key2 = vec![];
1427        let key3 = vec![
1428            (name_desc, false),
1429            (age_asc, false),
1430            (created_at_desc, false),
1431        ];
1432
1433        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
1434            (key1, vec![max1, min3]),
1435            (key2, vec![max2]),
1436            (key3, vec![sum4]),
1437        ];
1438        assert_eq!(expected, result);
1439        Ok(())
1440    }
1441
1442    #[test]
1443    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
1444        let asc_or_desc = [true, false];
1445        let nulls_first_or_last = [true, false];
1446        let partition_by = &[col("age"), col("name"), col("created_at")];
1447        for asc_ in asc_or_desc {
1448            for nulls_first_ in nulls_first_or_last {
1449                let order_by = &[
1450                    Sort {
1451                        expr: col("age"),
1452                        asc: asc_,
1453                        nulls_first: nulls_first_,
1454                    },
1455                    Sort {
1456                        expr: col("name"),
1457                        asc: asc_,
1458                        nulls_first: nulls_first_,
1459                    },
1460                ];
1461
1462                let expected = vec![
1463                    (
1464                        Sort {
1465                            expr: col("age"),
1466                            asc: asc_,
1467                            nulls_first: nulls_first_,
1468                        },
1469                        true,
1470                    ),
1471                    (
1472                        Sort {
1473                            expr: col("name"),
1474                            asc: asc_,
1475                            nulls_first: nulls_first_,
1476                        },
1477                        true,
1478                    ),
1479                    (
1480                        Sort {
1481                            expr: col("created_at"),
1482                            asc: true,
1483                            nulls_first: false,
1484                        },
1485                        true,
1486                    ),
1487                ];
1488                let result = generate_sort_key(partition_by, order_by)?;
1489                assert_eq!(expected, result);
1490            }
1491        }
1492        Ok(())
1493    }
1494
1495    #[test]
1496    fn test_enumerate_grouping_sets() -> Result<()> {
1497        let multi_cols = vec![col("col1"), col("col2"), col("col3")];
1498        let simple_col = col("simple_col");
1499        let cube = cube(multi_cols.clone());
1500        let rollup = rollup(multi_cols.clone());
1501        let grouping_set = grouping_set(vec![multi_cols]);
1502
1503        // 1. col
1504        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
1505        let result = format!("[{}]", expr_vec_fmt!(sets));
1506        assert_eq!("[simple_col]", &result);
1507
1508        // 2. cube
1509        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
1510        let result = format!("[{}]", expr_vec_fmt!(sets));
1511        assert_eq!("[CUBE (col1, col2, col3)]", &result);
1512
1513        // 3. rollup
1514        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
1515        let result = format!("[{}]", expr_vec_fmt!(sets));
1516        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
1517
1518        // 4. col + cube
1519        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
1520        let result = format!("[{}]", expr_vec_fmt!(sets));
1521        assert_eq!(
1522            "[GROUPING SETS (\
1523            (simple_col), \
1524            (simple_col, col1), \
1525            (simple_col, col2), \
1526            (simple_col, col1, col2), \
1527            (simple_col, col3), \
1528            (simple_col, col1, col3), \
1529            (simple_col, col2, col3), \
1530            (simple_col, col1, col2, col3))]",
1531            &result
1532        );
1533
1534        // 5. col + rollup
1535        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
1536        let result = format!("[{}]", expr_vec_fmt!(sets));
1537        assert_eq!(
1538            "[GROUPING SETS (\
1539            (simple_col), \
1540            (simple_col, col1), \
1541            (simple_col, col1, col2), \
1542            (simple_col, col1, col2, col3))]",
1543            &result
1544        );
1545
1546        // 6. col + grouping_set
1547        let sets =
1548            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
1549        let result = format!("[{}]", expr_vec_fmt!(sets));
1550        assert_eq!(
1551            "[GROUPING SETS (\
1552            (simple_col, col1, col2, col3))]",
1553            &result
1554        );
1555
1556        // 7. col + grouping_set + rollup
1557        let sets = enumerate_grouping_sets(vec![
1558            simple_col.clone(),
1559            grouping_set,
1560            rollup.clone(),
1561        ])?;
1562        let result = format!("[{}]", expr_vec_fmt!(sets));
1563        assert_eq!(
1564            "[GROUPING SETS (\
1565            (simple_col, col1, col2, col3), \
1566            (simple_col, col1, col2, col3, col1), \
1567            (simple_col, col1, col2, col3, col1, col2), \
1568            (simple_col, col1, col2, col3, col1, col2, col3))]",
1569            &result
1570        );
1571
1572        // 8. col + cube + rollup
1573        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
1574        let result = format!("[{}]", expr_vec_fmt!(sets));
1575        assert_eq!(
1576            "[GROUPING SETS (\
1577            (simple_col), \
1578            (simple_col, col1), \
1579            (simple_col, col1, col2), \
1580            (simple_col, col1, col2, col3), \
1581            (simple_col, col1), \
1582            (simple_col, col1, col1), \
1583            (simple_col, col1, col1, col2), \
1584            (simple_col, col1, col1, col2, col3), \
1585            (simple_col, col2), \
1586            (simple_col, col2, col1), \
1587            (simple_col, col2, col1, col2), \
1588            (simple_col, col2, col1, col2, col3), \
1589            (simple_col, col1, col2), \
1590            (simple_col, col1, col2, col1), \
1591            (simple_col, col1, col2, col1, col2), \
1592            (simple_col, col1, col2, col1, col2, col3), \
1593            (simple_col, col3), \
1594            (simple_col, col3, col1), \
1595            (simple_col, col3, col1, col2), \
1596            (simple_col, col3, col1, col2, col3), \
1597            (simple_col, col1, col3), \
1598            (simple_col, col1, col3, col1), \
1599            (simple_col, col1, col3, col1, col2), \
1600            (simple_col, col1, col3, col1, col2, col3), \
1601            (simple_col, col2, col3), \
1602            (simple_col, col2, col3, col1), \
1603            (simple_col, col2, col3, col1, col2), \
1604            (simple_col, col2, col3, col1, col2, col3), \
1605            (simple_col, col1, col2, col3), \
1606            (simple_col, col1, col2, col3, col1), \
1607            (simple_col, col1, col2, col3, col1, col2), \
1608            (simple_col, col1, col2, col3, col1, col2, col3))]",
1609            &result
1610        );
1611
1612        Ok(())
1613    }
1614    #[test]
1615    fn test_split_conjunction() {
1616        let expr = col("a");
1617        let result = split_conjunction(&expr);
1618        assert_eq!(result, vec![&expr]);
1619    }
1620
1621    #[test]
1622    fn test_split_conjunction_two() {
1623        let expr = col("a").eq(lit(5)).and(col("b"));
1624        let expr1 = col("a").eq(lit(5));
1625        let expr2 = col("b");
1626
1627        let result = split_conjunction(&expr);
1628        assert_eq!(result, vec![&expr1, &expr2]);
1629    }
1630
1631    #[test]
1632    fn test_split_conjunction_alias() {
1633        let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
1634        let expr1 = col("a").eq(lit(5));
1635        let expr2 = col("b"); // has no alias
1636
1637        let result = split_conjunction(&expr);
1638        assert_eq!(result, vec![&expr1, &expr2]);
1639    }
1640
1641    #[test]
1642    fn test_split_conjunction_or() {
1643        let expr = col("a").eq(lit(5)).or(col("b"));
1644        let result = split_conjunction(&expr);
1645        assert_eq!(result, vec![&expr]);
1646    }
1647
1648    #[test]
1649    fn test_split_binary_owned() {
1650        let expr = col("a");
1651        assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
1652    }
1653
1654    #[test]
1655    fn test_split_binary_owned_two() {
1656        assert_eq!(
1657            split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
1658            vec![col("a").eq(lit(5)), col("b")]
1659        );
1660    }
1661
1662    #[test]
1663    fn test_split_binary_owned_different_op() {
1664        let expr = col("a").eq(lit(5)).or(col("b"));
1665        assert_eq!(
1666            // expr is connected by OR, but pass in AND
1667            split_binary_owned(expr.clone(), Operator::And),
1668            vec![expr]
1669        );
1670    }
1671
1672    #[test]
1673    fn test_split_conjunction_owned() {
1674        let expr = col("a");
1675        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1676    }
1677
1678    #[test]
1679    fn test_split_conjunction_owned_two() {
1680        assert_eq!(
1681            split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
1682            vec![col("a").eq(lit(5)), col("b")]
1683        );
1684    }
1685
1686    #[test]
1687    fn test_split_conjunction_owned_alias() {
1688        assert_eq!(
1689            split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
1690            vec![
1691                col("a").eq(lit(5)),
1692                // no alias on b
1693                col("b"),
1694            ]
1695        );
1696    }
1697
1698    #[test]
1699    fn test_conjunction_empty() {
1700        assert_eq!(conjunction(vec![]), None);
1701    }
1702
1703    #[test]
1704    fn test_conjunction() {
1705        // `[A, B, C]`
1706        let expr = conjunction(vec![col("a"), col("b"), col("c")]);
1707
1708        // --> `(A AND B) AND C`
1709        assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));
1710
1711        // which is different than `A AND (B AND C)`
1712        assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
1713    }
1714
1715    #[test]
1716    fn test_disjunction_empty() {
1717        assert_eq!(disjunction(vec![]), None);
1718    }
1719
1720    #[test]
1721    fn test_disjunction() {
1722        // `[A, B, C]`
1723        let expr = disjunction(vec![col("a"), col("b"), col("c")]);
1724
1725        // --> `(A OR B) OR C`
1726        assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));
1727
1728        // which is different than `A OR (B OR C)`
1729        assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
1730    }
1731
1732    #[test]
1733    fn test_split_conjunction_owned_or() {
1734        let expr = col("a").eq(lit(5)).or(col("b"));
1735        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1736    }
1737
1738    #[test]
1739    fn test_collect_expr() -> Result<()> {
1740        let mut accum: HashSet<Column> = HashSet::new();
1741        expr_to_columns(
1742            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1743            &mut accum,
1744        )?;
1745        expr_to_columns(
1746            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1747            &mut accum,
1748        )?;
1749        assert_eq!(1, accum.len());
1750        assert!(accum.contains(&Column::from_name("a")));
1751        Ok(())
1752    }
1753
1754    #[test]
1755    fn test_can_hash() {
1756        let union_fields: UnionFields = [
1757            (0, Arc::new(Field::new("A", DataType::Int32, true))),
1758            (1, Arc::new(Field::new("B", DataType::Float64, true))),
1759        ]
1760        .into_iter()
1761        .collect();
1762
1763        let union_type = DataType::Union(union_fields, UnionMode::Sparse);
1764        assert!(!can_hash(&union_type));
1765
1766        let list_union_type =
1767            DataType::List(Arc::new(Field::new("my_union", union_type, true)));
1768        assert!(!can_hash(&list_union_type));
1769    }
1770
1771    #[test]
1772    fn test_generate_signature_error_msg_with_parameter_names() {
1773        let sig = Signature::one_of(
1774            vec![
1775                TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
1776                TypeSignature::Exact(vec![
1777                    DataType::Utf8,
1778                    DataType::Int64,
1779                    DataType::Int64,
1780                ]),
1781            ],
1782            Volatility::Immutable,
1783        )
1784        .with_parameter_names(vec![
1785            "str".to_string(),
1786            "start_pos".to_string(),
1787            "length".to_string(),
1788        ])
1789        .expect("valid parameter names");
1790
1791        // Generate error message with only 1 argument provided
1792        let error_msg =
1793            generate_signature_error_message("substr", &sig, &[DataType::Utf8]);
1794
1795        assert!(
1796            error_msg.contains("str: Utf8, start_pos: Int64"),
1797            "Expected 'str: Utf8, start_pos: Int64' in error message, got: {error_msg}"
1798        );
1799        assert!(
1800            error_msg.contains("str: Utf8, start_pos: Int64, length: Int64"),
1801            "Expected 'str: Utf8, start_pos: Int64, length: Int64' in error message, got: {error_msg}"
1802        );
1803    }
1804
1805    #[test]
1806    fn test_generate_signature_error_msg_without_parameter_names() {
1807        let sig = Signature::one_of(
1808            vec![TypeSignature::Any(2), TypeSignature::Any(3)],
1809            Volatility::Immutable,
1810        );
1811
1812        let error_msg =
1813            generate_signature_error_message("my_func", &sig, &[DataType::Int32]);
1814
1815        assert!(
1816            error_msg.contains("Any, Any"),
1817            "Expected 'Any, Any' without parameter names, got: {error_msg}"
1818        );
1819    }
1820
1821    #[test]
1822    fn test_signature_error_msg_exact() {
1823        use insta::assert_snapshot;
1824
1825        let sig = Signature::one_of(
1826            vec![
1827                TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]),
1828                TypeSignature::Exact(vec![DataType::Float32, DataType::Int64]),
1829                TypeSignature::Exact(vec![DataType::Float64]),
1830                TypeSignature::Exact(vec![DataType::Float32]),
1831            ],
1832            Volatility::Immutable,
1833        );
1834        let msg = generate_signature_error_message(
1835            "round",
1836            &sig,
1837            &[DataType::Float64, DataType::Float64],
1838        );
1839        assert_snapshot!(msg, @r"
1840        No function matches the given name and argument types 'round(Float64, Float64)'. You might need to add explicit type casts.
1841        	Candidate functions:
1842        	round(Float64, Int64)
1843        	round(Float32, Int64)
1844        	round(Float64)
1845        	round(Float32)
1846        ");
1847    }
1848
1849    #[test]
1850    fn test_signature_error_msg_coercible() {
1851        use datafusion_common::types::NativeType;
1852        use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
1853        use insta::assert_snapshot;
1854
1855        let sig = Signature::coercible(
1856            vec![
1857                Coercion::new_implicit(
1858                    TypeSignatureClass::Native(
1859                        datafusion_common::types::logical_float64(),
1860                    ),
1861                    vec![TypeSignatureClass::Numeric],
1862                    NativeType::Float64,
1863                ),
1864                Coercion::new_implicit(
1865                    TypeSignatureClass::Native(datafusion_common::types::logical_int64()),
1866                    vec![TypeSignatureClass::Integer],
1867                    NativeType::Int64,
1868                ),
1869            ],
1870            Volatility::Immutable,
1871        );
1872        let msg = generate_signature_error_message(
1873            "round",
1874            &sig,
1875            &[DataType::Utf8, DataType::Utf8],
1876        );
1877        assert_snapshot!(msg, @r"
1878        No function matches the given name and argument types 'round(Utf8, Utf8)'. You might need to add explicit type casts.
1879        	Candidate functions:
1880        	round(Float64, Int64)
1881        ");
1882    }
1883
1884    #[test]
1885    fn test_signature_error_msg_with_names_coercible() {
1886        use datafusion_common::types::NativeType;
1887        use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
1888        use insta::assert_snapshot;
1889
1890        let sig = Signature::coercible(
1891            vec![
1892                Coercion::new_exact(TypeSignatureClass::Native(
1893                    datafusion_common::types::logical_string(),
1894                )),
1895                Coercion::new_exact(TypeSignatureClass::Native(
1896                    datafusion_common::types::logical_int64(),
1897                )),
1898                Coercion::new_implicit(
1899                    TypeSignatureClass::Native(datafusion_common::types::logical_int64()),
1900                    vec![TypeSignatureClass::Integer],
1901                    NativeType::Int64,
1902                ),
1903            ],
1904            Volatility::Immutable,
1905        )
1906        .with_parameter_names(vec![
1907            "string".to_string(),
1908            "start_pos".to_string(),
1909            "length".to_string(),
1910        ])
1911        .expect("valid parameter names");
1912
1913        let msg = generate_signature_error_message("substr", &sig, &[DataType::Int32]);
1914        assert_snapshot!(msg, @r"
1915        No function matches the given name and argument types 'substr(Int32)'. You might need to add explicit type casts.
1916        	Candidate functions:
1917        	substr(string: String, start_pos: Int64, length: Int64)
1918        ");
1919    }
1920}