Skip to main content

datafusion_expr/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Expression utilities
19
20use std::cmp::Ordering;
21use std::collections::{BTreeSet, HashSet};
22use std::sync::Arc;
23
24use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
25use crate::expr_rewriter::strip_outer_reference;
26use crate::{
27    BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, and,
28};
29use datafusion_expr_common::signature::{Signature, TypeSignature};
30
31use arrow::datatypes::{DataType, Field, Schema};
32use datafusion_common::tree_node::{
33    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34};
35use datafusion_common::utils::get_at_indices;
36use datafusion_common::{
37    Column, DFSchema, DFSchemaRef, HashMap, Result, TableReference, internal_err,
38    plan_err,
39};
40
41#[cfg(not(feature = "sql"))]
42use crate::expr::{ExceptSelectItem, ExcludeSelectItem};
43use indexmap::IndexSet;
44#[cfg(feature = "sql")]
45use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
46
47pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
48
49///  The value to which `COUNT(*)` is expanded to in
50///  `COUNT(<constant>)` expressions
51pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
52
53/// Count the number of distinct exprs in a list of group by expressions. If the
54/// first element is a `GroupingSet` expression then it must be the only expr.
55pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
56    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
57        if group_expr.len() > 1 {
58            return plan_err!(
59                "Invalid group by expressions, GroupingSet must be the only expression"
60            );
61        }
62        // Groupings sets have an additional integral column for the grouping id
63        Ok(grouping_set.distinct_expr().len() + 1)
64    } else {
65        grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
66    }
67}
68
/// Internal helper that enumerates the subsets of `0..len` as index vectors.
///
/// Subset number `mask` (for `mask` in `0..2^len`) contains index `i` exactly
/// when bit `i` of `mask` is set; indices within a subset are ascending.
/// Callers must keep `len < 64` so the mask fits in a `u64`.
fn powerset_indices(len: usize) -> impl Iterator<Item = Vec<usize>> {
    (0..(1u64 << len)).map(move |mask: u64| {
        (0..len)
            .filter(|&bit| ((mask >> bit) & 1) == 1)
            .collect::<Vec<usize>>()
    })
}
85
86/// The [power set] (or powerset) of a set S is the set of all subsets of S, \
87/// including the empty set and S itself.
88///
89/// Example:
90///
91/// If S is the set {x, y, z}, then all the subsets of S are \
92///  {} \
93///  {x} \
94///  {y} \
95///  {z} \
96///  {x, y} \
97///  {x, z} \
98///  {y, z} \
99///  {x, y, z} \
100///  and hence the power set of S is {{}, {x}, {y}, {z}, {x, y}, {x, z}, {y, z}, {x, y, z}}.
101///
102/// [power set]: https://en.wikipedia.org/wiki/Power_set
103pub fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>> {
104    if slice.len() >= 64 {
105        return plan_err!("The size of the set must be less than 64");
106    }
107
108    Ok(powerset_indices(slice.len())
109        .map(|indices| indices.iter().map(|&idx| &slice[idx]).collect())
110        .collect())
111}
112
113/// check the number of expressions contained in the grouping_set
114fn check_grouping_set_size_limit(size: usize) -> Result<()> {
115    let max_grouping_set_size = 65535;
116    if size > max_grouping_set_size {
117        return plan_err!(
118            "The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"
119        );
120    }
121
122    Ok(())
123}
124
125/// check the number of grouping_set contained in the grouping sets
126fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
127    let max_grouping_sets_size = 4096;
128    if size > max_grouping_sets_size {
129        return plan_err!(
130            "The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"
131        );
132    }
133
134    Ok(())
135}
136
137/// Merge two grouping_set
138///
139/// # Example
140/// ```text
141/// (A, B), (C, D) -> (A, B, C, D)
142/// ```
143///
144/// # Error
145/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
146///
147/// [`DataFusionError`]: datafusion_common::DataFusionError
148fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
149    check_grouping_set_size_limit(left.len() + right.len())?;
150    Ok(left.iter().chain(right.iter()).cloned().collect())
151}
152
153/// Compute the cross product of two grouping_sets
154///
155/// # Example
156/// ```text
157/// [(A, B), (C, D)], [(E), (F)] -> [(A, B, E), (A, B, F), (C, D, E), (C, D, F)]
158/// ```
159///
160/// # Error
161/// - [`DataFusionError`]: The number of group_expression in grouping_set exceeds the maximum limit
162/// - [`DataFusionError`]: The number of grouping_set in grouping_sets exceeds the maximum limit
163///
164/// [`DataFusionError`]: datafusion_common::DataFusionError
165fn cross_join_grouping_sets<T: Clone>(
166    left: &[Vec<T>],
167    right: &[Vec<T>],
168) -> Result<Vec<Vec<T>>> {
169    let grouping_sets_size = left.len() * right.len();
170
171    check_grouping_sets_size_limit(grouping_sets_size)?;
172
173    let mut result = Vec::with_capacity(grouping_sets_size);
174    for le in left {
175        for re in right {
176            result.push(merge_grouping_set(le, re)?);
177        }
178    }
179    Ok(result)
180}
181
182/// Convert multiple grouping expressions into one [`GroupingSet::GroupingSets`],\
183/// if the grouping expression does not contain [`Expr::GroupingSet`] or only has one expression,\
184/// no conversion will be performed.
185///
186/// e.g.
187///
188/// person.id,\
189/// GROUPING SETS ((person.age, person.salary),(person.age)),\
190/// ROLLUP(person.state, person.birth_date)
191///
192/// =>
193///
194/// GROUPING SETS (\
195///   (person.id, person.age, person.salary),\
196///   (person.id, person.age, person.salary, person.state),\
197///   (person.id, person.age, person.salary, person.state, person.birth_date),\
198///   (person.id, person.age),\
199///   (person.id, person.age, person.state),\
200///   (person.id, person.age, person.state, person.birth_date)\
201/// )
202pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
203    let has_grouping_set = group_expr
204        .iter()
205        .any(|expr| matches!(expr, Expr::GroupingSet(_)));
206    if !has_grouping_set || group_expr.len() == 1 {
207        return Ok(group_expr);
208    }
209    // Only process mix grouping sets
210    let partial_sets = group_expr
211        .iter()
212        .map(|expr| {
213            let exprs = match expr {
214                Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
215                    check_grouping_sets_size_limit(grouping_sets.len())?;
216                    grouping_sets.iter().map(|e| e.iter().collect()).collect()
217                }
218                Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
219                    let grouping_sets = powerset(group_exprs)?;
220                    check_grouping_sets_size_limit(grouping_sets.len())?;
221                    grouping_sets
222                }
223                Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
224                    let size = group_exprs.len();
225                    let slice = group_exprs.as_slice();
226                    check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
227                    (0..(size + 1))
228                        .map(|i| slice[0..i].iter().collect())
229                        .collect()
230                }
231                expr => vec![vec![expr]],
232            };
233            Ok(exprs)
234        })
235        .collect::<Result<Vec<_>>>()?;
236
237    // Cross Join
238    let grouping_sets = partial_sets
239        .into_iter()
240        .map(Ok)
241        .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
242        .transpose()?
243        .map(|e| {
244            e.into_iter()
245                .map(|e| e.into_iter().cloned().collect())
246                .collect()
247        })
248        .unwrap_or_default();
249
250    Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
251        grouping_sets,
252    ))])
253}
254
255/// Find all distinct exprs in a list of group by expressions. If the
256/// first element is a `GroupingSet` expression then it must be the only expr.
257pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
258    if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
259        if group_expr.len() > 1 {
260            return plan_err!(
261                "Invalid group by expressions, GroupingSet must be the only expression"
262            );
263        }
264        Ok(grouping_set.distinct_expr())
265    } else {
266        Ok(group_expr
267            .iter()
268            .collect::<IndexSet<_>>()
269            .into_iter()
270            .collect())
271    }
272}
273
/// Recursively walk an expression tree, collecting the unique set of columns
/// referenced in the expression.
///
/// Every `Expr::Column` node visited by `Expr::apply` is cloned into `accum`.
/// Because `accum` is a `HashSet`, repeated references are stored once.
/// `Expr::OuterReferenceColumn` nodes are *not* collected.
///
/// # Errors
/// Propagates any error from `Expr::apply`. The visitor closure itself always
/// returns `Ok`.
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
    expr.apply(|expr| {
        match expr {
            Expr::Column(qc) => {
                accum.insert(qc.clone());
            }
            // Use explicit pattern match instead of a default
            // implementation, so that in the future if someone adds
            // new Expr types, they will check here as well
            // TODO: remove the next line after `Expr::Wildcard` is removed
            #[expect(deprecated)]
            Expr::Unnest(_)
            | Expr::ScalarVariable(_, _)
            | Expr::Alias(_)
            | Expr::Literal(_, _)
            | Expr::BinaryExpr { .. }
            | Expr::Like { .. }
            | Expr::SimilarTo { .. }
            | Expr::Not(_)
            | Expr::IsNotNull(_)
            | Expr::IsNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Negative(_)
            | Expr::Between { .. }
            | Expr::Case { .. }
            | Expr::Cast { .. }
            | Expr::TryCast { .. }
            | Expr::ScalarFunction(..)
            | Expr::WindowFunction { .. }
            | Expr::AggregateFunction { .. }
            | Expr::GroupingSet(_)
            | Expr::InList { .. }
            | Expr::Exists { .. }
            | Expr::InSubquery(_)
            | Expr::SetComparison(_)
            | Expr::ScalarSubquery(_)
            | Expr::Wildcard { .. }
            | Expr::Placeholder(_)
            | Expr::OuterReferenceColumn { .. } => {}
        }
        // Always keep descending so nested columns are found too
        Ok(TreeNodeRecursion::Continue)
    })
    .map(|_| ())
}
325
326/// Find excluded columns in the schema, if any
327/// SELECT * EXCLUDE(col1, col2), would return `vec![col1, col2]`
328fn get_excluded_columns(
329    opt_exclude: Option<&ExcludeSelectItem>,
330    opt_except: Option<&ExceptSelectItem>,
331    schema: &DFSchema,
332    qualifier: Option<&TableReference>,
333) -> Result<Vec<Column>> {
334    let mut idents = vec![];
335    if let Some(excepts) = opt_except {
336        idents.push(&excepts.first_element);
337        idents.extend(&excepts.additional_elements);
338    }
339    if let Some(exclude) = opt_exclude {
340        match exclude {
341            ExcludeSelectItem::Single(ident) => idents.push(ident),
342            ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
343        }
344    }
345    // Excluded columns should be unique
346    let n_elem = idents.len();
347    let unique_idents = idents.into_iter().collect::<HashSet<_>>();
348    // If HashSet size, and vector length are different, this means that some of the excluded columns
349    // are not unique. In this case return error.
350    if n_elem != unique_idents.len() {
351        return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
352    }
353
354    let mut result = vec![];
355    for ident in unique_idents.into_iter() {
356        let col_name = ident.value.as_str();
357        let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
358        result.push(Column::from((qualifier, field)));
359    }
360    Ok(result)
361}
362
363/// Returns all `Expr`s in the schema, except the `Column`s in the `columns_to_skip`
364fn get_exprs_except_skipped(
365    schema: &DFSchema,
366    columns_to_skip: &HashSet<Column>,
367) -> Vec<Expr> {
368    if columns_to_skip.is_empty() {
369        schema.iter().map(Expr::from).collect::<Vec<Expr>>()
370    } else {
371        schema
372            .columns()
373            .iter()
374            .filter_map(|c| {
375                if !columns_to_skip.contains(c) {
376                    Some(Expr::Column(c.clone()))
377                } else {
378                    None
379                }
380            })
381            .collect::<Vec<Expr>>()
382    }
383}
384
385/// For each column specified in the USING JOIN condition, the JOIN plan outputs it twice
386/// (once for each join side), but an unqualified wildcard should include it only once.
387/// This function returns the columns that should be excluded.
388fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
389    let using_columns = plan.using_columns()?;
390    let excluded = using_columns
391        .into_iter()
392        // For each USING JOIN condition, only expand to one of each join column in projection
393        .flat_map(|cols| {
394            let mut cols = cols.into_iter().collect::<Vec<_>>();
395            // sort join columns to make sure we consistently keep the same
396            // qualified column
397            cols.sort();
398            let mut out_column_names: HashSet<String> = HashSet::new();
399            cols.into_iter().filter_map(move |c| {
400                if out_column_names.contains(&c.name) {
401                    Some(c)
402                } else {
403                    out_column_names.insert(c.name);
404                    None
405                }
406            })
407        })
408        .collect::<HashSet<_>>();
409    Ok(excluded)
410}
411
412/// Resolves an `Expr::Wildcard` to a collection of `Expr::Column`'s.
413pub fn expand_wildcard(
414    schema: &DFSchema,
415    plan: &LogicalPlan,
416    wildcard_options: Option<&WildcardOptions>,
417) -> Result<Vec<Expr>> {
418    let mut columns_to_skip = exclude_using_columns(plan)?;
419    let excluded_columns = if let Some(WildcardOptions {
420        exclude: opt_exclude,
421        except: opt_except,
422        ..
423    }) = wildcard_options
424    {
425        get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
426    } else {
427        vec![]
428    };
429    // Add each excluded `Column` to columns_to_skip
430    columns_to_skip.extend(excluded_columns);
431    Ok(get_exprs_except_skipped(schema, &columns_to_skip))
432}
433
434/// Resolves an `Expr::Wildcard` to a collection of qualified `Expr::Column`'s.
435pub fn expand_qualified_wildcard(
436    qualifier: &TableReference,
437    schema: &DFSchema,
438    wildcard_options: Option<&WildcardOptions>,
439) -> Result<Vec<Expr>> {
440    let qualified_indices = schema.fields_indices_with_qualified(qualifier);
441    let projected_func_dependencies = schema
442        .functional_dependencies()
443        .project_functional_dependencies(&qualified_indices, qualified_indices.len());
444    let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
445    if fields_with_qualified.is_empty() {
446        return plan_err!("Invalid qualifier {qualifier}");
447    }
448
449    let qualified_schema = Arc::new(Schema::new_with_metadata(
450        fields_with_qualified,
451        schema.metadata().clone(),
452    ));
453    let qualified_dfschema =
454        DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
455            .with_functional_dependencies(projected_func_dependencies)?;
456    let excluded_columns = if let Some(WildcardOptions {
457        exclude: opt_exclude,
458        except: opt_except,
459        ..
460    }) = wildcard_options
461    {
462        get_excluded_columns(
463            opt_exclude.as_ref(),
464            opt_except.as_ref(),
465            schema,
466            Some(qualifier),
467        )?
468    } else {
469        vec![]
470    };
471    // Add each excluded `Column` to columns_to_skip
472    let mut columns_to_skip = HashSet::new();
473    columns_to_skip.extend(excluded_columns);
474    Ok(get_exprs_except_skipped(
475        &qualified_dfschema,
476        &columns_to_skip,
477    ))
478}
479
/// A window's combined sort key.
///
/// Each entry pairs a [`Sort`] expression with a flag:
/// - `true` if the key was derived from a `PARTITION BY` expression;
/// - `false` if it came from an `ORDER BY` expression.
type WindowSortKey = Vec<(Sort, bool)>;
483
484/// Generate a sort key for a given window expr's partition_by and order_by expr
485pub fn generate_sort_key(
486    partition_by: &[Expr],
487    order_by: &[Sort],
488) -> Result<WindowSortKey> {
489    let normalized_order_by_keys = order_by
490        .iter()
491        .map(|e| {
492            let Sort { expr, .. } = e;
493            Sort::new(expr.clone(), true, false)
494        })
495        .collect::<Vec<_>>();
496
497    let mut final_sort_keys = vec![];
498    let mut is_partition_flag = vec![];
499    partition_by.iter().for_each(|e| {
500        // By default, create sort key with ASC is true and NULLS LAST to be consistent with
501        // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html
502        let e = e.clone().sort(true, false);
503        if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
504            let order_by_key = &order_by[pos];
505            if !final_sort_keys.contains(order_by_key) {
506                final_sort_keys.push(order_by_key.clone());
507                is_partition_flag.push(true);
508            }
509        } else if !final_sort_keys.contains(&e) {
510            final_sort_keys.push(e);
511            is_partition_flag.push(true);
512        }
513    });
514
515    order_by.iter().for_each(|e| {
516        if !final_sort_keys.contains(e) {
517            final_sort_keys.push(e.clone());
518            is_partition_flag.push(false);
519        }
520    });
521    let res = final_sort_keys
522        .into_iter()
523        .zip(is_partition_flag)
524        .collect::<Vec<_>>();
525    Ok(res)
526}
527
528/// Compare the sort expr as PostgreSQL's common_prefix_cmp():
529/// <https://github.com/postgres/postgres/blob/master/src/backend/optimizer/plan/planner.c>
530pub fn compare_sort_expr(
531    sort_expr_a: &Sort,
532    sort_expr_b: &Sort,
533    schema: &DFSchemaRef,
534) -> Ordering {
535    let Sort {
536        expr: expr_a,
537        asc: asc_a,
538        nulls_first: nulls_first_a,
539    } = sort_expr_a;
540
541    let Sort {
542        expr: expr_b,
543        asc: asc_b,
544        nulls_first: nulls_first_b,
545    } = sort_expr_b;
546
547    let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
548    let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
549    for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
550        match idx_a.cmp(idx_b) {
551            Ordering::Less => {
552                return Ordering::Less;
553            }
554            Ordering::Greater => {
555                return Ordering::Greater;
556            }
557            Ordering::Equal => {}
558        }
559    }
560    match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
561        Ordering::Less => return Ordering::Greater,
562        Ordering::Greater => {
563            return Ordering::Less;
564        }
565        Ordering::Equal => {}
566    }
567    match (asc_a, asc_b) {
568        (true, false) => {
569            return Ordering::Greater;
570        }
571        (false, true) => {
572            return Ordering::Less;
573        }
574        _ => {}
575    }
576    match (nulls_first_a, nulls_first_b) {
577        (true, false) => {
578            return Ordering::Less;
579        }
580        (false, true) => {
581            return Ordering::Greater;
582        }
583        _ => {}
584    }
585    Ordering::Equal
586}
587
588/// Group a slice of window expression expr by their order by expressions
589pub fn group_window_expr_by_sort_keys(
590    window_expr: impl IntoIterator<Item = Expr>,
591) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
592    let mut result = vec![];
593    window_expr.into_iter().try_for_each(|expr| match &expr {
594        Expr::WindowFunction(window_fun) => {
595            let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params;
596            let sort_key = generate_sort_key(partition_by, order_by)?;
597            if let Some((_, values)) = result.iter_mut().find(
598                |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
599            ) {
600                values.push(expr);
601            } else {
602                result.push((sort_key, vec![expr]))
603            }
604            Ok(())
605        }
606        other => internal_err!(
607            "Impossibly got non-window expr {other:?}"
608        ),
609    })?;
610    Ok(result)
611}
612
613/// Collect all deeply nested `Expr::AggregateFunction`.
614/// They are returned in order of occurrence (depth
615/// first), with duplicates omitted.
616pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
617    find_exprs_in_exprs(exprs, &|nested_expr| {
618        matches!(nested_expr, Expr::AggregateFunction { .. })
619    })
620}
621
622/// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence
623/// (depth first), with duplicates omitted.
624pub fn find_window_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
625    find_exprs_in_exprs(exprs, &|nested_expr| {
626        matches!(nested_expr, Expr::WindowFunction { .. })
627    })
628}
629
630/// Collect all deeply nested `Expr::OuterReferenceColumn`. They are returned in order of occurrence
631/// (depth first), with duplicates omitted.
632pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
633    find_exprs_in_expr(expr, &|nested_expr| {
634        matches!(nested_expr, Expr::OuterReferenceColumn { .. })
635    })
636}
637
638/// Search the provided `Expr`'s, and all of their nested `Expr`, for any that
639/// pass the provided test. The returned `Expr`'s are deduplicated and returned
640/// in order of appearance (depth first).
641fn find_exprs_in_exprs<'a, F>(
642    exprs: impl IntoIterator<Item = &'a Expr>,
643    test_fn: &F,
644) -> Vec<Expr>
645where
646    F: Fn(&Expr) -> bool,
647{
648    exprs
649        .into_iter()
650        .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
651        .fold(vec![], |mut acc, expr| {
652            if !acc.contains(&expr) {
653                acc.push(expr)
654            }
655            acc
656        })
657}
658
659/// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the
660/// provided test. The returned `Expr`'s are deduplicated and returned in order
661/// of appearance (depth first).
662fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
663where
664    F: Fn(&Expr) -> bool,
665{
666    let mut exprs = vec![];
667    expr.apply(|expr| {
668        if test_fn(expr) {
669            if !(exprs.contains(expr)) {
670                exprs.push(expr.clone())
671            }
672            // Stop recursing down this expr once we find a match
673            return Ok(TreeNodeRecursion::Jump);
674        }
675
676        Ok(TreeNodeRecursion::Continue)
677    })
678    // pre_visit always returns OK, so this will always too
679    .expect("no way to return error during recursion");
680    exprs
681}
682
683/// Recursively inspect an [`Expr`] and all its children.
684pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
685where
686    F: FnMut(&Expr) -> Result<(), E>,
687{
688    let mut err = Ok(());
689    expr.apply(|expr| {
690        if let Err(e) = f(expr) {
691            // Save the error for later (it may not be a DataFusionError)
692            err = Err(e);
693            Ok(TreeNodeRecursion::Stop)
694        } else {
695            // keep going
696            Ok(TreeNodeRecursion::Continue)
697        }
698    })
699    // The closure always returns OK, so this will always too
700    .expect("no way to return error during recursion");
701
702    err
703}
704
705/// Create schema fields from an expression list, for use in result set schema construction
706///
707/// This function converts a list of expressions into a list of complete schema fields,
708/// making comprehensive determinations about each field's properties including:
709/// - **Data type**: Resolved based on expression type and input schema context
710/// - **Nullability**: Determined by expression-specific nullability rules
711/// - **Metadata**: Computed based on expression type (preserving, merging, or generating new metadata)
712/// - **Table reference scoping**: Establishing proper qualified field references
713///
714/// Each expression is converted to a field by calling [`Expr::to_field`], which performs
715/// the complete field resolution process for all field properties.
716///
717/// # Returns
718///
719/// A `Result` containing a vector of `(Option<TableReference>, Arc<Field>)` tuples,
720/// where each Field contains complete schema information (type, nullability, metadata)
721/// and proper table reference scoping for the corresponding expression.
722pub fn exprlist_to_fields<'a>(
723    exprs: impl IntoIterator<Item = &'a Expr>,
724    plan: &LogicalPlan,
725) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
726    // Look for exact match in plan's output schema
727    let input_schema = plan.schema();
728    exprs
729        .into_iter()
730        .map(|e| e.to_field(input_schema))
731        .collect()
732}
733
734/// Convert an expression into Column expression if it's already provided as input plan.
735///
736/// For example, it rewrites:
737///
738/// ```text
739/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
740/// .project(vec![col("c1"), sum(col("c2"))?
741/// ```
742///
743/// Into:
744///
745/// ```text
746/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
747/// .project(vec![col("c1"), col("SUM(c2)")?
748/// ```
749pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
750    let output_exprs = match input.columnized_output_exprs() {
751        Ok(exprs) if !exprs.is_empty() => exprs,
752        _ => return Ok(e),
753    };
754    let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
755    e.transform_down(|node: Expr| match exprs_map.get(&node) {
756        Some(column) => Ok(Transformed::new(
757            Expr::Column(column.clone()),
758            true,
759            TreeNodeRecursion::Jump,
760        )),
761        None => Ok(Transformed::no(node)),
762    })
763    .data()
764}
765
766/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
767/// appearance (depth first), and may contain duplicates.
768pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
769    exprs
770        .iter()
771        .flat_map(find_columns_referenced_by_expr)
772        .map(Expr::Column)
773        .collect()
774}
775
776pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
777    let mut exprs = vec![];
778    e.apply(|expr| {
779        if let Expr::Column(c) = expr {
780            exprs.push(c.clone())
781        }
782        Ok(TreeNodeRecursion::Continue)
783    })
784    // As the closure always returns Ok, this "can't" error
785    .expect("Unexpected error");
786    exprs
787}
788
789/// Convert any `Expr` to an `Expr::Column`.
790pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
791    match expr {
792        Expr::Column(col) => {
793            let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
794            Ok(Expr::from(Column::from((qualifier, field))))
795        }
796        _ => Ok(Expr::Column(Column::from_name(
797            expr.schema_name().to_string(),
798        ))),
799    }
800}
801
802/// Recursively walk an expression tree, collecting the column indexes
803/// referenced in the expression
804pub(crate) fn find_column_indexes_referenced_by_expr(
805    e: &Expr,
806    schema: &DFSchemaRef,
807) -> Vec<usize> {
808    let mut indexes = vec![];
809    e.apply(|expr| {
810        match expr {
811            Expr::Column(qc) => {
812                if let Ok(idx) = schema.index_of_column(qc) {
813                    indexes.push(idx);
814                }
815            }
816            Expr::Literal(_, _) => {
817                indexes.push(usize::MAX);
818            }
819            _ => {}
820        }
821        Ok(TreeNodeRecursion::Continue)
822    })
823    .unwrap();
824    indexes
825}
826
827/// Can this data type be used in hash join equal conditions??
828/// Data types here come from function 'equal_rows', if more data types are supported
829/// in create_hashes, add those data types here to generate join logical plan.
830pub fn can_hash(data_type: &DataType) -> bool {
831    match data_type {
832        DataType::Null => true,
833        DataType::Boolean => true,
834        DataType::Int8 => true,
835        DataType::Int16 => true,
836        DataType::Int32 => true,
837        DataType::Int64 => true,
838        DataType::UInt8 => true,
839        DataType::UInt16 => true,
840        DataType::UInt32 => true,
841        DataType::UInt64 => true,
842        DataType::Float16 => true,
843        DataType::Float32 => true,
844        DataType::Float64 => true,
845        DataType::Decimal32(_, _) => true,
846        DataType::Decimal64(_, _) => true,
847        DataType::Decimal128(_, _) => true,
848        DataType::Decimal256(_, _) => true,
849        DataType::Timestamp(_, _) => true,
850        DataType::Utf8 => true,
851        DataType::LargeUtf8 => true,
852        DataType::Utf8View => true,
853        DataType::Binary => true,
854        DataType::LargeBinary => true,
855        DataType::BinaryView => true,
856        DataType::Date32 => true,
857        DataType::Date64 => true,
858        DataType::Time32(_) => true,
859        DataType::Time64(_) => true,
860        DataType::Duration(_) => true,
861        DataType::Interval(_) => true,
862        DataType::FixedSizeBinary(_) => true,
863        DataType::Dictionary(key_type, value_type) => {
864            DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
865        }
866        DataType::List(value_type) => can_hash(value_type.data_type()),
867        DataType::LargeList(value_type) => can_hash(value_type.data_type()),
868        DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
869        DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
870        DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
871
872        DataType::ListView(_)
873        | DataType::LargeListView(_)
874        | DataType::Union(_, _)
875        | DataType::RunEndEncoded(_, _) => false,
876    }
877}
878
879/// Check whether all columns are from the schema.
880pub fn check_all_columns_from_schema(
881    columns: &HashSet<&Column>,
882    schema: &DFSchema,
883) -> Result<bool> {
884    for col in columns.iter() {
885        let exist = schema.is_column_from_schema(col);
886        if !exist {
887            return Ok(false);
888        }
889    }
890
891    Ok(true)
892}
893
894/// Give two sides of the equijoin predicate, return a valid join key pair.
895/// If there is no valid join key pair, return None.
896///
897/// A valid join means:
898/// 1. All referenced column of the left side is from the left schema, and
899///    all referenced column of the right side is from the right schema.
900/// 2. Or opposite. All referenced column of the left side is from the right schema,
901///    and the right side is from the left schema.
902pub fn find_valid_equijoin_key_pair(
903    left_key: &Expr,
904    right_key: &Expr,
905    left_schema: &DFSchema,
906    right_schema: &DFSchema,
907) -> Result<Option<(Expr, Expr)>> {
908    let left_using_columns = left_key.column_refs();
909    let right_using_columns = right_key.column_refs();
910
911    // Conditions like a = 10, will be added to non-equijoin.
912    if left_using_columns.is_empty() || right_using_columns.is_empty() {
913        return Ok(None);
914    }
915
916    if check_all_columns_from_schema(&left_using_columns, left_schema)?
917        && check_all_columns_from_schema(&right_using_columns, right_schema)?
918    {
919        return Ok(Some((left_key.clone(), right_key.clone())));
920    } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
921        && check_all_columns_from_schema(&left_using_columns, right_schema)?
922    {
923        return Ok(Some((right_key.clone(), left_key.clone())));
924    }
925
926    Ok(None)
927}
928
929/// Creates a detailed error message for a function with wrong signature.
930///
931/// For example, a query like `select round(3.14, 1.1);` would yield:
932/// ```text
933/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts.
934///     Candidate functions:
935///     round(Float64, Int64)
936///     round(Float32, Int64)
937///     round(Float64)
938///     round(Float32)
939/// ```
940#[expect(clippy::needless_pass_by_value)]
941#[deprecated(since = "53.0.0", note = "Internal function")]
942pub fn generate_signature_error_msg(
943    func_name: &str,
944    func_signature: Signature,
945    input_expr_types: &[DataType],
946) -> String {
947    let candidate_signatures = func_signature
948        .type_signature
949        .to_string_repr_with_names(func_signature.parameter_names.as_deref())
950        .iter()
951        .map(|args_str| format!("\t{func_name}({args_str})"))
952        .collect::<Vec<String>>()
953        .join("\n");
954
955    format!(
956        "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
957        func_name,
958        TypeSignature::join_types(input_expr_types, ", "),
959        candidate_signatures
960    )
961}
962
963/// Creates a detailed error message for a function with wrong signature.
964///
965/// For example, a query like `select round(3.14, 1.1);` would yield:
966/// ```text
967/// Error during planning: No function matches 'round(Float64, Float64)'. You might need to add explicit type casts.
968///     Candidate functions:
969///     round(Float64, Int64)
970///     round(Float32, Int64)
971///     round(Float64)
972///     round(Float32)
973/// ```
974pub(crate) fn generate_signature_error_message(
975    func_name: &str,
976    func_signature: &Signature,
977    input_expr_types: &[DataType],
978) -> String {
979    #[expect(deprecated)]
980    generate_signature_error_msg(func_name, func_signature.clone(), input_expr_types)
981}
982
983/// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
984///
985/// See [`split_conjunction_owned`] for more details and an example.
986pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
987    split_conjunction_impl(expr, vec![])
988}
989
990fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
991    match expr {
992        Expr::BinaryExpr(BinaryExpr {
993            right,
994            op: Operator::And,
995            left,
996        }) => {
997            let exprs = split_conjunction_impl(left, exprs);
998            split_conjunction_impl(right, exprs)
999        }
1000        Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
1001        other => {
1002            exprs.push(other);
1003            exprs
1004        }
1005    }
1006}
1007
1008/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1009///
1010/// See [`split_conjunction_owned`] for more details and an example.
1011pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
1012    let mut stack = vec![expr];
1013    std::iter::from_fn(move || {
1014        while let Some(expr) = stack.pop() {
1015            match expr {
1016                Expr::BinaryExpr(BinaryExpr {
1017                    right,
1018                    op: Operator::And,
1019                    left,
1020                }) => {
1021                    stack.push(right);
1022                    stack.push(left);
1023                }
1024                Expr::Alias(Alias { expr, .. }) => stack.push(expr),
1025                other => return Some(other),
1026            }
1027        }
1028        None
1029    })
1030}
1031
1032/// Iterate parts in a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1033///
1034/// See [`split_conjunction_owned`] for more details and an example.
1035pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
1036    let mut stack = vec![expr];
1037    std::iter::from_fn(move || {
1038        while let Some(expr) = stack.pop() {
1039            match expr {
1040                Expr::BinaryExpr(BinaryExpr {
1041                    right,
1042                    op: Operator::And,
1043                    left,
1044                }) => {
1045                    stack.push(*right);
1046                    stack.push(*left);
1047                }
1048                Expr::Alias(Alias { expr, .. }) => stack.push(*expr),
1049                other => return Some(other),
1050            }
1051        }
1052        None
1053    })
1054}
1055
1056/// Splits an owned conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]`
1057///
1058/// This is often used to "split" filter expressions such as `col1 = 5
1059/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1060///
1061/// # Example
1062/// ```
1063/// # use datafusion_expr::{col, lit};
1064/// # use datafusion_expr::utils::split_conjunction_owned;
1065/// // a=1 AND b=2
1066/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1067///
1068/// // [a=1, b=2]
1069/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1070///
1071/// // use split_conjunction_owned to split them
1072/// assert_eq!(split_conjunction_owned(expr), split);
1073/// ```
1074pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1075    split_binary_owned(expr, Operator::And)
1076}
1077
1078/// Splits an owned binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1079///
1080/// This is often used to "split" expressions such as `col1 = 5
1081/// AND col2 = 10` into [`col1 = 5`, `col2 = 10`];
1082///
1083/// # Example
1084/// ```
1085/// # use datafusion_expr::{col, lit, Operator};
1086/// # use datafusion_expr::utils::split_binary_owned;
1087/// # use std::ops::Add;
1088/// // a=1 + b=2
1089/// let expr = col("a").eq(lit(1)).add(col("b").eq(lit(2)));
1090///
1091/// // [a=1, b=2]
1092/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1093///
1094/// // use split_binary_owned to split them
1095/// assert_eq!(split_binary_owned(expr, Operator::Plus), split);
1096/// ```
1097pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1098    split_binary_owned_impl(expr, op, vec![])
1099}
1100
1101fn split_binary_owned_impl(
1102    expr: Expr,
1103    operator: Operator,
1104    mut exprs: Vec<Expr>,
1105) -> Vec<Expr> {
1106    match expr {
1107        Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1108            let exprs = split_binary_owned_impl(*left, operator, exprs);
1109            split_binary_owned_impl(*right, operator, exprs)
1110        }
1111        Expr::Alias(Alias { expr, .. }) => {
1112            split_binary_owned_impl(*expr, operator, exprs)
1113        }
1114        other => {
1115            exprs.push(other);
1116            exprs
1117        }
1118    }
1119}
1120
1121/// Splits an binary operator tree [`Expr`] such as `A <OP> B <OP> C` => `[A, B, C]`
1122///
1123/// See [`split_binary_owned`] for more details and an example.
1124pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1125    split_binary_impl(expr, op, vec![])
1126}
1127
1128fn split_binary_impl<'a>(
1129    expr: &'a Expr,
1130    operator: Operator,
1131    mut exprs: Vec<&'a Expr>,
1132) -> Vec<&'a Expr> {
1133    match expr {
1134        Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1135            let exprs = split_binary_impl(left, operator, exprs);
1136            split_binary_impl(right, operator, exprs)
1137        }
1138        Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1139        other => {
1140            exprs.push(other);
1141            exprs
1142        }
1143    }
1144}
1145
1146/// Combines an array of filter expressions into a single filter
1147/// expression consisting of the input filter expressions joined with
1148/// logical AND.
1149///
1150/// Returns None if the filters array is empty.
1151///
1152/// # Example
1153/// ```
1154/// # use datafusion_expr::{col, lit};
1155/// # use datafusion_expr::utils::conjunction;
1156/// // a=1 AND b=2
1157/// let expr = col("a").eq(lit(1)).and(col("b").eq(lit(2)));
1158///
1159/// // [a=1, b=2]
1160/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1161///
1162/// // use conjunction to join them together with `AND`
1163/// assert_eq!(conjunction(split), Some(expr));
1164/// ```
1165pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1166    filters.into_iter().reduce(Expr::and)
1167}
1168
1169/// Combines an array of filter expressions into a single filter
1170/// expression consisting of the input filter expressions joined with
1171/// logical OR.
1172///
1173/// Returns None if the filters array is empty.
1174///
1175/// # Example
1176/// ```
1177/// # use datafusion_expr::{col, lit};
1178/// # use datafusion_expr::utils::disjunction;
1179/// // a=1 OR b=2
1180/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2)));
1181///
1182/// // [a=1, b=2]
1183/// let split = vec![col("a").eq(lit(1)), col("b").eq(lit(2))];
1184///
1185/// // use disjunction to join them together with `OR`
1186/// assert_eq!(disjunction(split), Some(expr));
1187/// ```
1188pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1189    filters.into_iter().reduce(Expr::or)
1190}
1191
1192/// Returns a new [LogicalPlan] that filters the output of  `plan` with a
1193/// [LogicalPlan::Filter] with all `predicates` ANDed.
1194///
1195/// # Example
1196/// Before:
1197/// ```text
1198/// plan
1199/// ```
1200///
1201/// After:
1202/// ```text
1203/// Filter(predicate)
1204///   plan
1205/// ```
1206pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1207    // reduce filters to a single filter with an AND
1208    let predicate = predicates
1209        .iter()
1210        .skip(1)
1211        .fold(predicates[0].clone(), |acc, predicate| {
1212            and(acc, (*predicate).to_owned())
1213        });
1214
1215    Ok(LogicalPlan::Filter(Filter::try_new(
1216        predicate,
1217        Arc::new(plan),
1218    )?))
1219}
1220
1221/// Looks for correlating expressions: for example, a binary expression with one field from the subquery, and
1222/// one not in the subquery (closed upon from outer scope)
1223///
1224/// # Arguments
1225///
1226/// * `exprs` - List of expressions that may or may not be joins
1227///
1228/// # Return value
1229///
1230/// Tuple of (expressions containing joins, remaining non-join expressions)
1231pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1232    let mut joins = vec![];
1233    let mut others = vec![];
1234    for filter in exprs.into_iter() {
1235        // If the expression contains correlated predicates, add it to join filters
1236        if filter.contains_outer() {
1237            if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1238            {
1239                joins.push(strip_outer_reference((*filter).clone()));
1240            }
1241        } else {
1242            others.push((*filter).clone());
1243        }
1244    }
1245
1246    Ok((joins, others))
1247}
1248
1249/// Returns the first (and only) element in a slice, or an error
1250///
1251/// # Arguments
1252///
1253/// * `slice` - The slice to extract from
1254///
1255/// # Return value
1256///
1257/// The first element, or an error
1258pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1259    match slice {
1260        [it] => Ok(it),
1261        [] => plan_err!("No items found!"),
1262        _ => plan_err!("More than one item found!"),
1263    }
1264}
1265
1266/// merge inputs schema into a single schema.
1267///
1268/// This function merges schemas from multiple logical plan inputs using [`DFSchema::merge`].
1269/// Refer to that documentation for details on precedence and metadata handling.
1270pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1271    if inputs.len() == 1 {
1272        inputs[0].schema().as_ref().clone()
1273    } else {
1274        inputs.iter().map(|input| input.schema()).fold(
1275            DFSchema::empty(),
1276            |mut lhs, rhs| {
1277                lhs.merge(rhs);
1278                lhs
1279            },
1280        )
1281    }
1282}
1283
/// Builds the field name for an intermediate state of an aggregate function by
/// appending `state_name` in square brackets to `name`, e.g.
/// `("sum", "count")` becomes `"sum[count]"`.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    let mut out = String::with_capacity(name.len() + state_name.len() + 2);
    out.push_str(name);
    out.push('[');
    out.push_str(state_name);
    out.push(']');
    out
}
1288
1289/// Determine the set of [`Column`]s produced by the subquery.
1290pub fn collect_subquery_cols(
1291    exprs: &[Expr],
1292    subquery_schema: &DFSchema,
1293) -> Result<BTreeSet<Column>> {
1294    exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
1295        let mut using_cols: Vec<Column> = vec![];
1296        for col in expr.column_refs().into_iter() {
1297            if subquery_schema.has_column(col) {
1298                using_cols.push(col.clone());
1299            }
1300        }
1301
1302        cols.extend(using_cols);
1303        Result::<_>::Ok(cols)
1304    })
1305}
1306
1307#[cfg(test)]
1308mod tests {
1309    use super::*;
1310    use crate::{
1311        Cast, ExprFunctionExt, WindowFunctionDefinition, col, cube,
1312        expr::WindowFunction,
1313        expr_vec_fmt, grouping_set, lit, rollup,
1314        test::function_stub::{max_udaf, min_udaf, sum_udaf},
1315    };
1316    use arrow::datatypes::{UnionFields, UnionMode};
1317    use datafusion_expr_common::signature::{TypeSignature, Volatility};
1318
1319    #[test]
1320    fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
1321        let result = group_window_expr_by_sort_keys(vec![])?;
1322        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
1323        assert_eq!(expected, result);
1324        Ok(())
1325    }
1326
1327    #[test]
1328    fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
1329        let max1 = Expr::from(WindowFunction::new(
1330            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1331            vec![col("name")],
1332        ));
1333        let max2 = Expr::from(WindowFunction::new(
1334            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1335            vec![col("name")],
1336        ));
1337        let min3 = Expr::from(WindowFunction::new(
1338            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1339            vec![col("name")],
1340        ));
1341        let sum4 = Expr::from(WindowFunction::new(
1342            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1343            vec![col("age")],
1344        ));
1345        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1346        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1347        let key = vec![];
1348        let expected: Vec<(WindowSortKey, Vec<Expr>)> =
1349            vec![(key, vec![max1, max2, min3, sum4])];
1350        assert_eq!(expected, result);
1351        Ok(())
1352    }
1353
1354    #[test]
1355    fn test_group_window_expr_by_sort_keys() -> Result<()> {
1356        let age_asc = Sort::new(col("age"), true, true);
1357        let name_desc = Sort::new(col("name"), false, true);
1358        let created_at_desc = Sort::new(col("created_at"), false, true);
1359        let max1 = Expr::from(WindowFunction::new(
1360            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1361            vec![col("name")],
1362        ))
1363        .order_by(vec![age_asc.clone(), name_desc.clone()])
1364        .build()
1365        .unwrap();
1366        let max2 = Expr::from(WindowFunction::new(
1367            WindowFunctionDefinition::AggregateUDF(max_udaf()),
1368            vec![col("name")],
1369        ));
1370        let min3 = Expr::from(WindowFunction::new(
1371            WindowFunctionDefinition::AggregateUDF(min_udaf()),
1372            vec![col("name")],
1373        ))
1374        .order_by(vec![age_asc.clone(), name_desc.clone()])
1375        .build()
1376        .unwrap();
1377        let sum4 = Expr::from(WindowFunction::new(
1378            WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1379            vec![col("age")],
1380        ))
1381        .order_by(vec![
1382            name_desc.clone(),
1383            age_asc.clone(),
1384            created_at_desc.clone(),
1385        ])
1386        .build()
1387        .unwrap();
1388        // FIXME use as_ref
1389        let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1390        let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1391
1392        let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
1393        let key2 = vec![];
1394        let key3 = vec![
1395            (name_desc, false),
1396            (age_asc, false),
1397            (created_at_desc, false),
1398        ];
1399
1400        let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
1401            (key1, vec![max1, min3]),
1402            (key2, vec![max2]),
1403            (key3, vec![sum4]),
1404        ];
1405        assert_eq!(expected, result);
1406        Ok(())
1407    }
1408
1409    #[test]
1410    fn avoid_generate_duplicate_sort_keys() -> Result<()> {
1411        let asc_or_desc = [true, false];
1412        let nulls_first_or_last = [true, false];
1413        let partition_by = &[col("age"), col("name"), col("created_at")];
1414        for asc_ in asc_or_desc {
1415            for nulls_first_ in nulls_first_or_last {
1416                let order_by = &[
1417                    Sort {
1418                        expr: col("age"),
1419                        asc: asc_,
1420                        nulls_first: nulls_first_,
1421                    },
1422                    Sort {
1423                        expr: col("name"),
1424                        asc: asc_,
1425                        nulls_first: nulls_first_,
1426                    },
1427                ];
1428
1429                let expected = vec![
1430                    (
1431                        Sort {
1432                            expr: col("age"),
1433                            asc: asc_,
1434                            nulls_first: nulls_first_,
1435                        },
1436                        true,
1437                    ),
1438                    (
1439                        Sort {
1440                            expr: col("name"),
1441                            asc: asc_,
1442                            nulls_first: nulls_first_,
1443                        },
1444                        true,
1445                    ),
1446                    (
1447                        Sort {
1448                            expr: col("created_at"),
1449                            asc: true,
1450                            nulls_first: false,
1451                        },
1452                        true,
1453                    ),
1454                ];
1455                let result = generate_sort_key(partition_by, order_by)?;
1456                assert_eq!(expected, result);
1457            }
1458        }
1459        Ok(())
1460    }
1461
1462    #[test]
1463    fn test_enumerate_grouping_sets() -> Result<()> {
1464        let multi_cols = vec![col("col1"), col("col2"), col("col3")];
1465        let simple_col = col("simple_col");
1466        let cube = cube(multi_cols.clone());
1467        let rollup = rollup(multi_cols.clone());
1468        let grouping_set = grouping_set(vec![multi_cols]);
1469
1470        // 1. col
1471        let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
1472        let result = format!("[{}]", expr_vec_fmt!(sets));
1473        assert_eq!("[simple_col]", &result);
1474
1475        // 2. cube
1476        let sets = enumerate_grouping_sets(vec![cube.clone()])?;
1477        let result = format!("[{}]", expr_vec_fmt!(sets));
1478        assert_eq!("[CUBE (col1, col2, col3)]", &result);
1479
1480        // 3. rollup
1481        let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
1482        let result = format!("[{}]", expr_vec_fmt!(sets));
1483        assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
1484
1485        // 4. col + cube
1486        let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
1487        let result = format!("[{}]", expr_vec_fmt!(sets));
1488        assert_eq!(
1489            "[GROUPING SETS (\
1490            (simple_col), \
1491            (simple_col, col1), \
1492            (simple_col, col2), \
1493            (simple_col, col1, col2), \
1494            (simple_col, col3), \
1495            (simple_col, col1, col3), \
1496            (simple_col, col2, col3), \
1497            (simple_col, col1, col2, col3))]",
1498            &result
1499        );
1500
1501        // 5. col + rollup
1502        let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
1503        let result = format!("[{}]", expr_vec_fmt!(sets));
1504        assert_eq!(
1505            "[GROUPING SETS (\
1506            (simple_col), \
1507            (simple_col, col1), \
1508            (simple_col, col1, col2), \
1509            (simple_col, col1, col2, col3))]",
1510            &result
1511        );
1512
1513        // 6. col + grouping_set
1514        let sets =
1515            enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
1516        let result = format!("[{}]", expr_vec_fmt!(sets));
1517        assert_eq!(
1518            "[GROUPING SETS (\
1519            (simple_col, col1, col2, col3))]",
1520            &result
1521        );
1522
1523        // 7. col + grouping_set + rollup
1524        let sets = enumerate_grouping_sets(vec![
1525            simple_col.clone(),
1526            grouping_set,
1527            rollup.clone(),
1528        ])?;
1529        let result = format!("[{}]", expr_vec_fmt!(sets));
1530        assert_eq!(
1531            "[GROUPING SETS (\
1532            (simple_col, col1, col2, col3), \
1533            (simple_col, col1, col2, col3, col1), \
1534            (simple_col, col1, col2, col3, col1, col2), \
1535            (simple_col, col1, col2, col3, col1, col2, col3))]",
1536            &result
1537        );
1538
1539        // 8. col + cube + rollup
1540        let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
1541        let result = format!("[{}]", expr_vec_fmt!(sets));
1542        assert_eq!(
1543            "[GROUPING SETS (\
1544            (simple_col), \
1545            (simple_col, col1), \
1546            (simple_col, col1, col2), \
1547            (simple_col, col1, col2, col3), \
1548            (simple_col, col1), \
1549            (simple_col, col1, col1), \
1550            (simple_col, col1, col1, col2), \
1551            (simple_col, col1, col1, col2, col3), \
1552            (simple_col, col2), \
1553            (simple_col, col2, col1), \
1554            (simple_col, col2, col1, col2), \
1555            (simple_col, col2, col1, col2, col3), \
1556            (simple_col, col1, col2), \
1557            (simple_col, col1, col2, col1), \
1558            (simple_col, col1, col2, col1, col2), \
1559            (simple_col, col1, col2, col1, col2, col3), \
1560            (simple_col, col3), \
1561            (simple_col, col3, col1), \
1562            (simple_col, col3, col1, col2), \
1563            (simple_col, col3, col1, col2, col3), \
1564            (simple_col, col1, col3), \
1565            (simple_col, col1, col3, col1), \
1566            (simple_col, col1, col3, col1, col2), \
1567            (simple_col, col1, col3, col1, col2, col3), \
1568            (simple_col, col2, col3), \
1569            (simple_col, col2, col3, col1), \
1570            (simple_col, col2, col3, col1, col2), \
1571            (simple_col, col2, col3, col1, col2, col3), \
1572            (simple_col, col1, col2, col3), \
1573            (simple_col, col1, col2, col3, col1), \
1574            (simple_col, col1, col2, col3, col1, col2), \
1575            (simple_col, col1, col2, col3, col1, col2, col3))]",
1576            &result
1577        );
1578
1579        Ok(())
1580    }
1581    #[test]
1582    fn test_split_conjunction() {
1583        let expr = col("a");
1584        let result = split_conjunction(&expr);
1585        assert_eq!(result, vec![&expr]);
1586    }
1587
1588    #[test]
1589    fn test_split_conjunction_two() {
1590        let expr = col("a").eq(lit(5)).and(col("b"));
1591        let expr1 = col("a").eq(lit(5));
1592        let expr2 = col("b");
1593
1594        let result = split_conjunction(&expr);
1595        assert_eq!(result, vec![&expr1, &expr2]);
1596    }
1597
1598    #[test]
1599    fn test_split_conjunction_alias() {
1600        let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
1601        let expr1 = col("a").eq(lit(5));
1602        let expr2 = col("b"); // has no alias
1603
1604        let result = split_conjunction(&expr);
1605        assert_eq!(result, vec![&expr1, &expr2]);
1606    }
1607
1608    #[test]
1609    fn test_split_conjunction_or() {
1610        let expr = col("a").eq(lit(5)).or(col("b"));
1611        let result = split_conjunction(&expr);
1612        assert_eq!(result, vec![&expr]);
1613    }
1614
1615    #[test]
1616    fn test_split_binary_owned() {
1617        let expr = col("a");
1618        assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
1619    }
1620
1621    #[test]
1622    fn test_split_binary_owned_two() {
1623        assert_eq!(
1624            split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
1625            vec![col("a").eq(lit(5)), col("b")]
1626        );
1627    }
1628
1629    #[test]
1630    fn test_split_binary_owned_different_op() {
1631        let expr = col("a").eq(lit(5)).or(col("b"));
1632        assert_eq!(
1633            // expr is connected by OR, but pass in AND
1634            split_binary_owned(expr.clone(), Operator::And),
1635            vec![expr]
1636        );
1637    }
1638
1639    #[test]
1640    fn test_split_conjunction_owned() {
1641        let expr = col("a");
1642        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1643    }
1644
1645    #[test]
1646    fn test_split_conjunction_owned_two() {
1647        assert_eq!(
1648            split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
1649            vec![col("a").eq(lit(5)), col("b")]
1650        );
1651    }
1652
1653    #[test]
1654    fn test_split_conjunction_owned_alias() {
1655        assert_eq!(
1656            split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
1657            vec![
1658                col("a").eq(lit(5)),
1659                // no alias on b
1660                col("b"),
1661            ]
1662        );
1663    }
1664
1665    #[test]
1666    fn test_conjunction_empty() {
1667        assert_eq!(conjunction(vec![]), None);
1668    }
1669
1670    #[test]
1671    fn test_conjunction() {
1672        // `[A, B, C]`
1673        let expr = conjunction(vec![col("a"), col("b"), col("c")]);
1674
1675        // --> `(A AND B) AND C`
1676        assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));
1677
1678        // which is different than `A AND (B AND C)`
1679        assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
1680    }
1681
1682    #[test]
1683    fn test_disjunction_empty() {
1684        assert_eq!(disjunction(vec![]), None);
1685    }
1686
1687    #[test]
1688    fn test_disjunction() {
1689        // `[A, B, C]`
1690        let expr = disjunction(vec![col("a"), col("b"), col("c")]);
1691
1692        // --> `(A OR B) OR C`
1693        assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));
1694
1695        // which is different than `A OR (B OR C)`
1696        assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
1697    }
1698
1699    #[test]
1700    fn test_split_conjunction_owned_or() {
1701        let expr = col("a").eq(lit(5)).or(col("b"));
1702        assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1703    }
1704
1705    #[test]
1706    fn test_collect_expr() -> Result<()> {
1707        let mut accum: HashSet<Column> = HashSet::new();
1708        expr_to_columns(
1709            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1710            &mut accum,
1711        )?;
1712        expr_to_columns(
1713            &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1714            &mut accum,
1715        )?;
1716        assert_eq!(1, accum.len());
1717        assert!(accum.contains(&Column::from_name("a")));
1718        Ok(())
1719    }
1720
1721    #[test]
1722    fn test_can_hash() {
1723        let union_fields: UnionFields = [
1724            (0, Arc::new(Field::new("A", DataType::Int32, true))),
1725            (1, Arc::new(Field::new("B", DataType::Float64, true))),
1726        ]
1727        .into_iter()
1728        .collect();
1729
1730        let union_type = DataType::Union(union_fields, UnionMode::Sparse);
1731        assert!(!can_hash(&union_type));
1732
1733        let list_union_type =
1734            DataType::List(Arc::new(Field::new("my_union", union_type, true)));
1735        assert!(!can_hash(&list_union_type));
1736    }
1737
1738    #[test]
1739    fn test_generate_signature_error_msg_with_parameter_names() {
1740        let sig = Signature::one_of(
1741            vec![
1742                TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
1743                TypeSignature::Exact(vec![
1744                    DataType::Utf8,
1745                    DataType::Int64,
1746                    DataType::Int64,
1747                ]),
1748            ],
1749            Volatility::Immutable,
1750        )
1751        .with_parameter_names(vec![
1752            "str".to_string(),
1753            "start_pos".to_string(),
1754            "length".to_string(),
1755        ])
1756        .expect("valid parameter names");
1757
1758        // Generate error message with only 1 argument provided
1759        let error_msg =
1760            generate_signature_error_message("substr", &sig, &[DataType::Utf8]);
1761
1762        assert!(
1763            error_msg.contains("str: Utf8, start_pos: Int64"),
1764            "Expected 'str: Utf8, start_pos: Int64' in error message, got: {error_msg}"
1765        );
1766        assert!(
1767            error_msg.contains("str: Utf8, start_pos: Int64, length: Int64"),
1768            "Expected 'str: Utf8, start_pos: Int64, length: Int64' in error message, got: {error_msg}"
1769        );
1770    }
1771
1772    #[test]
1773    fn test_generate_signature_error_msg_without_parameter_names() {
1774        let sig = Signature::one_of(
1775            vec![TypeSignature::Any(2), TypeSignature::Any(3)],
1776            Volatility::Immutable,
1777        );
1778
1779        let error_msg =
1780            generate_signature_error_message("my_func", &sig, &[DataType::Int32]);
1781
1782        assert!(
1783            error_msg.contains("Any, Any"),
1784            "Expected 'Any, Any' without parameter names, got: {error_msg}"
1785        );
1786    }
1787}