1use std::cmp::Ordering;
21use std::collections::{BTreeSet, HashSet};
22use std::sync::Arc;
23
24use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
25use crate::expr_rewriter::strip_outer_reference;
26use crate::{
27 BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator, and,
28};
29use datafusion_expr_common::signature::{Signature, TypeSignature};
30
31use arrow::datatypes::{DataType, Field, Schema};
32use datafusion_common::tree_node::{
33 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34};
35use datafusion_common::utils::get_at_indices;
36use datafusion_common::{
37 Column, DFSchema, DFSchemaRef, HashMap, Result, TableReference, internal_err,
38 plan_err,
39};
40
41#[cfg(not(feature = "sql"))]
42use crate::sql::{ExceptSelectItem, ExcludeSelectItem, Ident, ObjectName};
43use indexmap::IndexSet;
44#[cfg(feature = "sql")]
45use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem, Ident, ObjectName};
46
47pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
48
49pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
52
53pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
56 if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
57 if group_expr.len() > 1 {
58 return plan_err!(
59 "Invalid group by expressions, GroupingSet must be the only expression"
60 );
61 }
62 Ok(grouping_set.distinct_expr().len() + 1)
64 } else {
65 grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
66 }
67}
68
/// Lazily enumerates every subset of `0..len` as a vector of ascending
/// indices.
///
/// Subsets are produced in increasing bit-mask order: bit `i` of the mask
/// selects index `i`, so the first item is the empty subset and the last one
/// holds every index. `len` must be below 64 because the mask is a `u64`.
fn powerset_indices(len: usize) -> impl Iterator<Item = Vec<usize>> {
    let subset_count: u64 = 1 << len;
    (0..subset_count).map(|mask| {
        let mut remaining = mask;
        let mut indices = Vec::with_capacity(mask.count_ones() as usize);
        while remaining != 0 {
            // The position of the lowest set bit is the next selected index.
            indices.push(remaining.trailing_zeros() as usize);
            // Clear that lowest set bit.
            remaining &= remaining - 1;
        }
        indices
    })
}
85
86pub fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>> {
104 if slice.len() >= 64 {
105 return plan_err!("The size of the set must be less than 64");
106 }
107
108 Ok(powerset_indices(slice.len())
109 .map(|indices| indices.iter().map(|&idx| &slice[idx]).collect())
110 .collect())
111}
112
113fn check_grouping_set_size_limit(size: usize) -> Result<()> {
115 let max_grouping_set_size = 65535;
116 if size > max_grouping_set_size {
117 return plan_err!(
118 "The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}"
119 );
120 }
121
122 Ok(())
123}
124
125fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
127 let max_grouping_sets_size = 4096;
128 if size > max_grouping_sets_size {
129 return plan_err!(
130 "The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}"
131 );
132 }
133
134 Ok(())
135}
136
137fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
149 check_grouping_set_size_limit(left.len() + right.len())?;
150 Ok(left.iter().chain(right.iter()).cloned().collect())
151}
152
153fn cross_join_grouping_sets<T: Clone>(
166 left: &[Vec<T>],
167 right: &[Vec<T>],
168) -> Result<Vec<Vec<T>>> {
169 let grouping_sets_size = left.len() * right.len();
170
171 check_grouping_sets_size_limit(grouping_sets_size)?;
172
173 let mut result = Vec::with_capacity(grouping_sets_size);
174 for le in left {
175 for re in right {
176 result.push(merge_grouping_set(le, re)?);
177 }
178 }
179 Ok(result)
180}
181
182pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
203 let has_grouping_set = group_expr
204 .iter()
205 .any(|expr| matches!(expr, Expr::GroupingSet(_)));
206 if !has_grouping_set || group_expr.len() == 1 {
207 return Ok(group_expr);
208 }
209 let partial_sets = group_expr
211 .iter()
212 .map(|expr| {
213 let exprs = match expr {
214 Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
215 check_grouping_sets_size_limit(grouping_sets.len())?;
216 grouping_sets.iter().map(|e| e.iter().collect()).collect()
217 }
218 Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
219 let grouping_sets = powerset(group_exprs)?;
220 check_grouping_sets_size_limit(grouping_sets.len())?;
221 grouping_sets
222 }
223 Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
224 let size = group_exprs.len();
225 let slice = group_exprs.as_slice();
226 check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
227 (0..(size + 1))
228 .map(|i| slice[0..i].iter().collect())
229 .collect()
230 }
231 expr => vec![vec![expr]],
232 };
233 Ok(exprs)
234 })
235 .collect::<Result<Vec<_>>>()?;
236
237 let grouping_sets = partial_sets
239 .into_iter()
240 .map(Ok)
241 .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
242 .transpose()?
243 .map(|e| {
244 e.into_iter()
245 .map(|e| e.into_iter().cloned().collect())
246 .collect()
247 })
248 .unwrap_or_default();
249
250 Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
251 grouping_sets,
252 ))])
253}
254
255pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
258 if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
259 if group_expr.len() > 1 {
260 return plan_err!(
261 "Invalid group by expressions, GroupingSet must be the only expression"
262 );
263 }
264 Ok(grouping_set.distinct_expr())
265 } else {
266 Ok(group_expr
267 .iter()
268 .collect::<IndexSet<_>>()
269 .into_iter()
270 .collect())
271 }
272}
273
274pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
277 expr.apply(|expr| {
278 match expr {
279 Expr::Column(qc) => {
280 accum.insert(qc.clone());
281 }
282 #[expect(deprecated)]
287 Expr::Unnest(_)
288 | Expr::ScalarVariable(_, _)
289 | Expr::Alias(_)
290 | Expr::Literal(_, _)
291 | Expr::BinaryExpr { .. }
292 | Expr::Like { .. }
293 | Expr::SimilarTo { .. }
294 | Expr::Not(_)
295 | Expr::IsNotNull(_)
296 | Expr::IsNull(_)
297 | Expr::IsTrue(_)
298 | Expr::IsFalse(_)
299 | Expr::IsUnknown(_)
300 | Expr::IsNotTrue(_)
301 | Expr::IsNotFalse(_)
302 | Expr::IsNotUnknown(_)
303 | Expr::Negative(_)
304 | Expr::Between { .. }
305 | Expr::Case { .. }
306 | Expr::Cast { .. }
307 | Expr::TryCast { .. }
308 | Expr::ScalarFunction(..)
309 | Expr::WindowFunction { .. }
310 | Expr::AggregateFunction { .. }
311 | Expr::GroupingSet(_)
312 | Expr::InList { .. }
313 | Expr::Exists { .. }
314 | Expr::InSubquery(_)
315 | Expr::SetComparison(_)
316 | Expr::ScalarSubquery(_)
317 | Expr::Wildcard { .. }
318 | Expr::Placeholder(_)
319 | Expr::OuterReferenceColumn { .. }
320 | Expr::HigherOrderFunction(_)
321 | Expr::Lambda(_)
322 | Expr::LambdaVariable(_) => {}
323 }
324 Ok(TreeNodeRecursion::Continue)
325 })
326 .map(|_| ())
327}
328
329fn get_excluded_columns(
332 opt_exclude: Option<&ExcludeSelectItem>,
333 opt_except: Option<&ExceptSelectItem>,
334 schema: &DFSchema,
335 qualifier: Option<&TableReference>,
336) -> Result<Vec<Column>> {
337 let mut idents = vec![];
338 if let Some(excepts) = opt_except {
339 idents.push(&excepts.first_element);
340 idents.extend(&excepts.additional_elements);
341 }
342 let exclude_owned: Vec<Ident>;
345 if let Some(exclude) = opt_exclude {
346 let object_name_to_ident = |name: &ObjectName| -> Result<Ident> {
347 if name.0.len() != 1 {
348 return plan_err!(
349 "EXCLUDE with multi-part identifiers is not supported: {name}"
350 );
351 }
352 let part = &name.0[0];
353 let Some(ident) = part.as_ident() else {
354 return plan_err!(
355 "EXCLUDE with non-identifier name part is not supported: {part}"
356 );
357 };
358 Ok(ident.clone())
359 };
360 exclude_owned = match exclude {
361 ExcludeSelectItem::Single(name) => vec![object_name_to_ident(name)?],
362 ExcludeSelectItem::Multiple(names) => names
363 .iter()
364 .map(object_name_to_ident)
365 .collect::<Result<Vec<_>>>()?,
366 };
367 idents.extend(exclude_owned.iter());
368 }
369 let n_elem = idents.len();
371 let unique_idents = idents.into_iter().collect::<HashSet<_>>();
372 if n_elem != unique_idents.len() {
375 return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
376 }
377
378 let mut result = vec![];
379 for ident in unique_idents.into_iter() {
380 let col_name = ident.value.as_str();
381 let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
382 result.push(Column::from((qualifier, field)));
383 }
384 Ok(result)
385}
386
387fn get_exprs_except_skipped(
389 schema: &DFSchema,
390 columns_to_skip: &HashSet<Column>,
391) -> Vec<Expr> {
392 if columns_to_skip.is_empty() {
393 schema.iter().map(Expr::from).collect::<Vec<Expr>>()
394 } else {
395 schema
396 .columns()
397 .iter()
398 .filter_map(|c| {
399 if !columns_to_skip.contains(c) {
400 Some(Expr::Column(c.clone()))
401 } else {
402 None
403 }
404 })
405 .collect::<Vec<Expr>>()
406 }
407}
408
409fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
414 let output_columns: HashSet<_> = plan.schema().columns().iter().cloned().collect();
415 let mut excluded = HashSet::new();
416 for cols in plan.using_columns()? {
417 let mut cols: Vec<_> = cols
423 .into_iter()
424 .filter(|c| output_columns.contains(c))
425 .collect();
426
427 cols.sort();
430
431 let mut seen_names = HashSet::new();
434 for col in cols {
435 if seen_names.contains(col.name.as_str()) {
436 excluded.insert(col); } else {
438 seen_names.insert(col.name.clone()); }
440 }
441 }
442 Ok(excluded)
443}
444
445pub fn expand_wildcard(
447 schema: &DFSchema,
448 plan: &LogicalPlan,
449 wildcard_options: Option<&WildcardOptions>,
450) -> Result<Vec<Expr>> {
451 let mut columns_to_skip = exclude_using_columns(plan)?;
452 let excluded_columns = if let Some(WildcardOptions {
453 exclude: opt_exclude,
454 except: opt_except,
455 ..
456 }) = wildcard_options
457 {
458 get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
459 } else {
460 vec![]
461 };
462 columns_to_skip.extend(excluded_columns);
464 Ok(get_exprs_except_skipped(schema, &columns_to_skip))
465}
466
467pub fn expand_qualified_wildcard(
469 qualifier: &TableReference,
470 schema: &DFSchema,
471 wildcard_options: Option<&WildcardOptions>,
472) -> Result<Vec<Expr>> {
473 let qualified_indices = schema.fields_indices_with_qualified(qualifier);
474 let projected_func_dependencies = schema
475 .functional_dependencies()
476 .project_functional_dependencies(&qualified_indices, qualified_indices.len());
477 let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
478 if fields_with_qualified.is_empty() {
479 return plan_err!("Invalid qualifier {qualifier}");
480 }
481
482 let qualified_schema = Arc::new(Schema::new_with_metadata(
483 fields_with_qualified,
484 schema.metadata().clone(),
485 ));
486 let qualified_dfschema =
487 DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
488 .with_functional_dependencies(projected_func_dependencies)?;
489 let excluded_columns = if let Some(WildcardOptions {
490 exclude: opt_exclude,
491 except: opt_except,
492 ..
493 }) = wildcard_options
494 {
495 get_excluded_columns(
496 opt_exclude.as_ref(),
497 opt_except.as_ref(),
498 schema,
499 Some(qualifier),
500 )?
501 } else {
502 vec![]
503 };
504 let mut columns_to_skip = HashSet::new();
506 columns_to_skip.extend(excluded_columns);
507 Ok(get_exprs_except_skipped(
508 &qualified_dfschema,
509 &columns_to_skip,
510 ))
511}
512
/// Sort key of a window expression: each [`Sort`] is paired with a flag that
/// `generate_sort_key` sets to `true` when the key originates from the
/// partition-by expressions and to `false` when it originates only from the
/// order-by expressions.
type WindowSortKey = Vec<(Sort, bool)>;
516
517pub fn generate_sort_key(
519 partition_by: &[Expr],
520 order_by: &[Sort],
521) -> Result<WindowSortKey> {
522 let normalized_order_by_keys = order_by
523 .iter()
524 .map(|e| {
525 let Sort { expr, .. } = e;
526 Sort::new(expr.clone(), true, false)
527 })
528 .collect::<Vec<_>>();
529
530 let mut final_sort_keys = vec![];
531 let mut is_partition_flag = vec![];
532 partition_by.iter().for_each(|e| {
533 let e = e.clone().sort(true, false);
536 if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
537 let order_by_key = &order_by[pos];
538 if !final_sort_keys.contains(order_by_key) {
539 final_sort_keys.push(order_by_key.clone());
540 is_partition_flag.push(true);
541 }
542 } else if !final_sort_keys.contains(&e) {
543 final_sort_keys.push(e);
544 is_partition_flag.push(true);
545 }
546 });
547
548 order_by.iter().for_each(|e| {
549 if !final_sort_keys.contains(e) {
550 final_sort_keys.push(e.clone());
551 is_partition_flag.push(false);
552 }
553 });
554 let res = final_sort_keys
555 .into_iter()
556 .zip(is_partition_flag)
557 .collect::<Vec<_>>();
558 Ok(res)
559}
560
561pub fn compare_sort_expr(
564 sort_expr_a: &Sort,
565 sort_expr_b: &Sort,
566 schema: &DFSchemaRef,
567) -> Ordering {
568 let Sort {
569 expr: expr_a,
570 asc: asc_a,
571 nulls_first: nulls_first_a,
572 } = sort_expr_a;
573
574 let Sort {
575 expr: expr_b,
576 asc: asc_b,
577 nulls_first: nulls_first_b,
578 } = sort_expr_b;
579
580 let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
581 let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
582 for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
583 match idx_a.cmp(idx_b) {
584 Ordering::Less => {
585 return Ordering::Less;
586 }
587 Ordering::Greater => {
588 return Ordering::Greater;
589 }
590 Ordering::Equal => {}
591 }
592 }
593 match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
594 Ordering::Less => return Ordering::Greater,
595 Ordering::Greater => {
596 return Ordering::Less;
597 }
598 Ordering::Equal => {}
599 }
600 match (asc_a, asc_b) {
601 (true, false) => {
602 return Ordering::Greater;
603 }
604 (false, true) => {
605 return Ordering::Less;
606 }
607 _ => {}
608 }
609 match (nulls_first_a, nulls_first_b) {
610 (true, false) => {
611 return Ordering::Less;
612 }
613 (false, true) => {
614 return Ordering::Greater;
615 }
616 _ => {}
617 }
618 Ordering::Equal
619}
620
621pub fn group_window_expr_by_sort_keys(
623 window_expr: impl IntoIterator<Item = Expr>,
624) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
625 let mut result = vec![];
626 window_expr.into_iter().try_for_each(|expr| match &expr {
627 Expr::WindowFunction(window_fun) => {
628 let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params;
629 let sort_key = generate_sort_key(partition_by, order_by)?;
630 if let Some((_, values)) = result.iter_mut().find(
631 |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
632 ) {
633 values.push(expr);
634 } else {
635 result.push((sort_key, vec![expr]))
636 }
637 Ok(())
638 }
639 other => internal_err!(
640 "Impossibly got non-window expr {other:?}"
641 ),
642 })?;
643 Ok(result)
644}
645
646pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
650 find_exprs_in_exprs(exprs, &|nested_expr| {
651 matches!(nested_expr, Expr::AggregateFunction { .. })
652 })
653}
654
655pub fn find_window_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
658 find_exprs_in_exprs(exprs, &|nested_expr| {
659 matches!(nested_expr, Expr::WindowFunction { .. })
660 })
661}
662
663pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
666 find_exprs_in_expr(expr, &|nested_expr| {
667 matches!(nested_expr, Expr::OuterReferenceColumn { .. })
668 })
669}
670
671fn find_exprs_in_exprs<'a, F>(
675 exprs: impl IntoIterator<Item = &'a Expr>,
676 test_fn: &F,
677) -> Vec<Expr>
678where
679 F: Fn(&Expr) -> bool,
680{
681 exprs
682 .into_iter()
683 .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
684 .fold(vec![], |mut acc, expr| {
685 if !acc.contains(&expr) {
686 acc.push(expr)
687 }
688 acc
689 })
690}
691
692fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
696where
697 F: Fn(&Expr) -> bool,
698{
699 let mut exprs = vec![];
700 expr.apply(|expr| {
701 if test_fn(expr) {
702 if !(exprs.contains(expr)) {
703 exprs.push(expr.clone())
704 }
705 return Ok(TreeNodeRecursion::Jump);
707 }
708
709 Ok(TreeNodeRecursion::Continue)
710 })
711 .expect("no way to return error during recursion");
713 exprs
714}
715
716pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
718where
719 F: FnMut(&Expr) -> Result<(), E>,
720{
721 let mut err = Ok(());
722 expr.apply(|expr| {
723 if let Err(e) = f(expr) {
724 err = Err(e);
726 Ok(TreeNodeRecursion::Stop)
727 } else {
728 Ok(TreeNodeRecursion::Continue)
730 }
731 })
732 .expect("no way to return error during recursion");
734
735 err
736}
737
738pub fn exprlist_to_fields<'a>(
756 exprs: impl IntoIterator<Item = &'a Expr>,
757 plan: &LogicalPlan,
758) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
759 let input_schema = plan.schema();
761 exprs
762 .into_iter()
763 .map(|e| e.to_field(input_schema))
764 .collect()
765}
766
767pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
783 let output_exprs = match input.columnized_output_exprs() {
784 Ok(exprs) if !exprs.is_empty() => exprs,
785 _ => return Ok(e),
786 };
787 let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
788 e.transform_down(|node: Expr| match exprs_map.get(&node) {
789 Some(column) => Ok(Transformed::new(
790 Expr::Column(column.clone()),
791 true,
792 TreeNodeRecursion::Jump,
793 )),
794 None => Ok(Transformed::no(node)),
795 })
796 .data()
797}
798
799pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
802 exprs
803 .iter()
804 .flat_map(find_columns_referenced_by_expr)
805 .map(Expr::Column)
806 .collect()
807}
808
809pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
810 let mut exprs = vec![];
811 e.apply(|expr| {
812 if let Expr::Column(c) = expr {
813 exprs.push(c.clone())
814 }
815 Ok(TreeNodeRecursion::Continue)
816 })
817 .expect("Unexpected error");
819 exprs
820}
821
822pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
824 match expr {
825 Expr::Column(col) => {
826 let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
827 Ok(Expr::from(Column::from((qualifier, field))))
828 }
829 _ => Ok(Expr::Column(Column::from_name(
830 expr.schema_name().to_string(),
831 ))),
832 }
833}
834
835pub(crate) fn find_column_indexes_referenced_by_expr(
838 e: &Expr,
839 schema: &DFSchemaRef,
840) -> Vec<usize> {
841 let mut indexes = vec![];
842 e.apply(|expr| {
843 match expr {
844 Expr::Column(qc) => {
845 if let Ok(idx) = schema.index_of_column(qc) {
846 indexes.push(idx);
847 }
848 }
849 Expr::Literal(_, _) => {
850 indexes.push(usize::MAX);
851 }
852 _ => {}
853 }
854 Ok(TreeNodeRecursion::Continue)
855 })
856 .unwrap();
857 indexes
858}
859
860pub fn can_hash(data_type: &DataType) -> bool {
864 match data_type {
865 DataType::Null => true,
866 DataType::Boolean => true,
867 DataType::Int8 => true,
868 DataType::Int16 => true,
869 DataType::Int32 => true,
870 DataType::Int64 => true,
871 DataType::UInt8 => true,
872 DataType::UInt16 => true,
873 DataType::UInt32 => true,
874 DataType::UInt64 => true,
875 DataType::Float16 => true,
876 DataType::Float32 => true,
877 DataType::Float64 => true,
878 DataType::Decimal32(_, _) => true,
879 DataType::Decimal64(_, _) => true,
880 DataType::Decimal128(_, _) => true,
881 DataType::Decimal256(_, _) => true,
882 DataType::Timestamp(_, _) => true,
883 DataType::Utf8 => true,
884 DataType::LargeUtf8 => true,
885 DataType::Utf8View => true,
886 DataType::Binary => true,
887 DataType::LargeBinary => true,
888 DataType::BinaryView => true,
889 DataType::Date32 => true,
890 DataType::Date64 => true,
891 DataType::Time32(_) => true,
892 DataType::Time64(_) => true,
893 DataType::Duration(_) => true,
894 DataType::Interval(_) => true,
895 DataType::FixedSizeBinary(_) => true,
896 DataType::Dictionary(key_type, value_type) => {
897 DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
898 }
899 DataType::List(value_type) => can_hash(value_type.data_type()),
900 DataType::LargeList(value_type) => can_hash(value_type.data_type()),
901 DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
902 DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
903 DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
904
905 DataType::ListView(_)
906 | DataType::LargeListView(_)
907 | DataType::Union(_, _)
908 | DataType::RunEndEncoded(_, _) => false,
909 }
910}
911
912pub fn check_all_columns_from_schema(
914 columns: &HashSet<&Column>,
915 schema: &DFSchema,
916) -> Result<bool> {
917 for col in columns.iter() {
918 let exist = schema.is_column_from_schema(col);
919 if !exist {
920 return Ok(false);
921 }
922 }
923
924 Ok(true)
925}
926
927pub fn find_valid_equijoin_key_pair(
936 left_key: &Expr,
937 right_key: &Expr,
938 left_schema: &DFSchema,
939 right_schema: &DFSchema,
940) -> Result<Option<(Expr, Expr)>> {
941 let left_using_columns = left_key.column_refs();
942 let right_using_columns = right_key.column_refs();
943
944 if left_using_columns.is_empty() || right_using_columns.is_empty() {
946 return Ok(None);
947 }
948
949 if check_all_columns_from_schema(&left_using_columns, left_schema)?
950 && check_all_columns_from_schema(&right_using_columns, right_schema)?
951 {
952 return Ok(Some((left_key.clone(), right_key.clone())));
953 } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
954 && check_all_columns_from_schema(&left_using_columns, right_schema)?
955 {
956 return Ok(Some((right_key.clone(), left_key.clone())));
957 }
958
959 Ok(None)
960}
961
962#[expect(clippy::needless_pass_by_value)]
974#[deprecated(since = "53.0.0", note = "Internal function")]
975pub fn generate_signature_error_msg(
976 func_name: &str,
977 func_signature: Signature,
978 input_expr_types: &[DataType],
979) -> String {
980 let candidate_signatures = func_signature
981 .type_signature
982 .to_string_repr_with_names(func_signature.parameter_names.as_deref())
983 .iter()
984 .map(|args_str| format!("\t{func_name}({args_str})"))
985 .collect::<Vec<String>>()
986 .join("\n");
987
988 format!(
989 "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
990 func_name,
991 TypeSignature::join_types(input_expr_types, ", "),
992 candidate_signatures
993 )
994}
995
996pub(crate) fn generate_signature_error_message(
1008 func_name: &str,
1009 func_signature: &Signature,
1010 input_expr_types: &[DataType],
1011) -> String {
1012 #[expect(deprecated)]
1013 generate_signature_error_msg(func_name, func_signature.clone(), input_expr_types)
1014}
1015
1016pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
1020 split_conjunction_impl(expr, vec![])
1021}
1022
1023fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
1024 match expr {
1025 Expr::BinaryExpr(BinaryExpr {
1026 right,
1027 op: Operator::And,
1028 left,
1029 }) => {
1030 let exprs = split_conjunction_impl(left, exprs);
1031 split_conjunction_impl(right, exprs)
1032 }
1033 Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
1034 other => {
1035 exprs.push(other);
1036 exprs
1037 }
1038 }
1039}
1040
1041pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
1045 let mut stack = vec![expr];
1046 std::iter::from_fn(move || {
1047 while let Some(expr) = stack.pop() {
1048 match expr {
1049 Expr::BinaryExpr(BinaryExpr {
1050 right,
1051 op: Operator::And,
1052 left,
1053 }) => {
1054 stack.push(right);
1055 stack.push(left);
1056 }
1057 Expr::Alias(Alias { expr, .. }) => stack.push(expr),
1058 other => return Some(other),
1059 }
1060 }
1061 None
1062 })
1063}
1064
1065pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
1069 let mut stack = vec![expr];
1070 std::iter::from_fn(move || {
1071 while let Some(expr) = stack.pop() {
1072 match expr {
1073 Expr::BinaryExpr(BinaryExpr {
1074 right,
1075 op: Operator::And,
1076 left,
1077 }) => {
1078 stack.push(*right);
1079 stack.push(*left);
1080 }
1081 Expr::Alias(Alias { expr, .. }) => stack.push(*expr),
1082 other => return Some(other),
1083 }
1084 }
1085 None
1086 })
1087}
1088
1089pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1108 split_binary_owned(expr, Operator::And)
1109}
1110
1111pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1131 split_binary_owned_impl(expr, op, vec![])
1132}
1133
1134fn split_binary_owned_impl(
1135 expr: Expr,
1136 operator: Operator,
1137 mut exprs: Vec<Expr>,
1138) -> Vec<Expr> {
1139 match expr {
1140 Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1141 let exprs = split_binary_owned_impl(*left, operator, exprs);
1142 split_binary_owned_impl(*right, operator, exprs)
1143 }
1144 Expr::Alias(Alias { expr, .. }) => {
1145 split_binary_owned_impl(*expr, operator, exprs)
1146 }
1147 other => {
1148 exprs.push(other);
1149 exprs
1150 }
1151 }
1152}
1153
1154pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1158 split_binary_impl(expr, op, vec![])
1159}
1160
1161fn split_binary_impl<'a>(
1162 expr: &'a Expr,
1163 operator: Operator,
1164 mut exprs: Vec<&'a Expr>,
1165) -> Vec<&'a Expr> {
1166 match expr {
1167 Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1168 let exprs = split_binary_impl(left, operator, exprs);
1169 split_binary_impl(right, operator, exprs)
1170 }
1171 Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1172 other => {
1173 exprs.push(other);
1174 exprs
1175 }
1176 }
1177}
1178
1179pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1199 filters.into_iter().reduce(Expr::and)
1200}
1201
1202pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1222 filters.into_iter().reduce(Expr::or)
1223}
1224
1225pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1240 let predicate = predicates
1242 .iter()
1243 .skip(1)
1244 .fold(predicates[0].clone(), |acc, predicate| {
1245 and(acc, (*predicate).to_owned())
1246 });
1247
1248 Ok(LogicalPlan::Filter(Filter::try_new(
1249 predicate,
1250 Arc::new(plan),
1251 )?))
1252}
1253
1254pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1265 let mut joins = vec![];
1266 let mut others = vec![];
1267 for filter in exprs.into_iter() {
1268 if filter.contains_outer() {
1270 if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1271 {
1272 joins.push(strip_outer_reference((*filter).clone()));
1273 }
1274 } else {
1275 others.push((*filter).clone());
1276 }
1277 }
1278
1279 Ok((joins, others))
1280}
1281
1282pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1292 match slice {
1293 [it] => Ok(it),
1294 [] => plan_err!("No items found!"),
1295 _ => plan_err!("More than one item found!"),
1296 }
1297}
1298
1299pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1304 if inputs.len() == 1 {
1305 inputs[0].schema().as_ref().clone()
1306 } else {
1307 inputs.iter().map(|input| input.schema()).fold(
1308 DFSchema::empty(),
1309 |mut lhs, rhs| {
1310 lhs.merge(rhs);
1311 lhs
1312 },
1313 )
1314 }
1315}
1316
/// Builds a state field name of the form `name[state_name]`.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    let mut formatted = String::with_capacity(name.len() + state_name.len() + 2);
    formatted.push_str(name);
    formatted.push('[');
    formatted.push_str(state_name);
    formatted.push(']');
    formatted
}
1321
1322pub fn collect_subquery_cols(
1324 exprs: &[Expr],
1325 subquery_schema: &DFSchema,
1326) -> Result<BTreeSet<Column>> {
1327 exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
1328 let mut using_cols: Vec<Column> = vec![];
1329 for col in expr.column_refs().into_iter() {
1330 if subquery_schema.has_column(col) {
1331 using_cols.push(col.clone());
1332 }
1333 }
1334
1335 cols.extend(using_cols);
1336 Result::<_>::Ok(cols)
1337 })
1338}
1339
1340#[cfg(test)]
1341mod tests {
1342 use super::*;
1343 use crate::{
1344 Cast, ExprFunctionExt, WindowFunctionDefinition, col, cube,
1345 expr::WindowFunction,
1346 expr_vec_fmt, grouping_set, lit, rollup,
1347 test::function_stub::{max_udaf, min_udaf, sum_udaf},
1348 };
1349 use arrow::datatypes::{UnionFields, UnionMode};
1350 use datafusion_expr_common::signature::Volatility;
1351
1352 #[test]
1353 fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
1354 let result = group_window_expr_by_sort_keys(vec![])?;
1355 let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
1356 assert_eq!(expected, result);
1357 Ok(())
1358 }
1359
1360 #[test]
1361 fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
1362 let max1 = Expr::from(WindowFunction::new(
1363 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1364 vec![col("name")],
1365 ));
1366 let max2 = Expr::from(WindowFunction::new(
1367 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1368 vec![col("name")],
1369 ));
1370 let min3 = Expr::from(WindowFunction::new(
1371 WindowFunctionDefinition::AggregateUDF(min_udaf()),
1372 vec![col("name")],
1373 ));
1374 let sum4 = Expr::from(WindowFunction::new(
1375 WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1376 vec![col("age")],
1377 ));
1378 let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1379 let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1380 let key = vec![];
1381 let expected: Vec<(WindowSortKey, Vec<Expr>)> =
1382 vec![(key, vec![max1, max2, min3, sum4])];
1383 assert_eq!(expected, result);
1384 Ok(())
1385 }
1386
1387 #[test]
1388 fn test_group_window_expr_by_sort_keys() -> Result<()> {
1389 let age_asc = Sort::new(col("age"), true, true);
1390 let name_desc = Sort::new(col("name"), false, true);
1391 let created_at_desc = Sort::new(col("created_at"), false, true);
1392 let max1 = Expr::from(WindowFunction::new(
1393 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1394 vec![col("name")],
1395 ))
1396 .order_by(vec![age_asc.clone(), name_desc.clone()])
1397 .build()
1398 .unwrap();
1399 let max2 = Expr::from(WindowFunction::new(
1400 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1401 vec![col("name")],
1402 ));
1403 let min3 = Expr::from(WindowFunction::new(
1404 WindowFunctionDefinition::AggregateUDF(min_udaf()),
1405 vec![col("name")],
1406 ))
1407 .order_by(vec![age_asc.clone(), name_desc.clone()])
1408 .build()
1409 .unwrap();
1410 let sum4 = Expr::from(WindowFunction::new(
1411 WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1412 vec![col("age")],
1413 ))
1414 .order_by(vec![
1415 name_desc.clone(),
1416 age_asc.clone(),
1417 created_at_desc.clone(),
1418 ])
1419 .build()
1420 .unwrap();
1421 let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1423 let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1424
1425 let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
1426 let key2 = vec![];
1427 let key3 = vec![
1428 (name_desc, false),
1429 (age_asc, false),
1430 (created_at_desc, false),
1431 ];
1432
1433 let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
1434 (key1, vec![max1, min3]),
1435 (key2, vec![max2]),
1436 (key3, vec![sum4]),
1437 ];
1438 assert_eq!(expected, result);
1439 Ok(())
1440 }
1441
1442 #[test]
1443 fn avoid_generate_duplicate_sort_keys() -> Result<()> {
1444 let asc_or_desc = [true, false];
1445 let nulls_first_or_last = [true, false];
1446 let partition_by = &[col("age"), col("name"), col("created_at")];
1447 for asc_ in asc_or_desc {
1448 for nulls_first_ in nulls_first_or_last {
1449 let order_by = &[
1450 Sort {
1451 expr: col("age"),
1452 asc: asc_,
1453 nulls_first: nulls_first_,
1454 },
1455 Sort {
1456 expr: col("name"),
1457 asc: asc_,
1458 nulls_first: nulls_first_,
1459 },
1460 ];
1461
1462 let expected = vec![
1463 (
1464 Sort {
1465 expr: col("age"),
1466 asc: asc_,
1467 nulls_first: nulls_first_,
1468 },
1469 true,
1470 ),
1471 (
1472 Sort {
1473 expr: col("name"),
1474 asc: asc_,
1475 nulls_first: nulls_first_,
1476 },
1477 true,
1478 ),
1479 (
1480 Sort {
1481 expr: col("created_at"),
1482 asc: true,
1483 nulls_first: false,
1484 },
1485 true,
1486 ),
1487 ];
1488 let result = generate_sort_key(partition_by, order_by)?;
1489 assert_eq!(expected, result);
1490 }
1491 }
1492 Ok(())
1493 }
1494
1495 #[test]
1496 fn test_enumerate_grouping_sets() -> Result<()> {
1497 let multi_cols = vec![col("col1"), col("col2"), col("col3")];
1498 let simple_col = col("simple_col");
1499 let cube = cube(multi_cols.clone());
1500 let rollup = rollup(multi_cols.clone());
1501 let grouping_set = grouping_set(vec![multi_cols]);
1502
1503 let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
1505 let result = format!("[{}]", expr_vec_fmt!(sets));
1506 assert_eq!("[simple_col]", &result);
1507
1508 let sets = enumerate_grouping_sets(vec![cube.clone()])?;
1510 let result = format!("[{}]", expr_vec_fmt!(sets));
1511 assert_eq!("[CUBE (col1, col2, col3)]", &result);
1512
1513 let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
1515 let result = format!("[{}]", expr_vec_fmt!(sets));
1516 assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
1517
1518 let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
1520 let result = format!("[{}]", expr_vec_fmt!(sets));
1521 assert_eq!(
1522 "[GROUPING SETS (\
1523 (simple_col), \
1524 (simple_col, col1), \
1525 (simple_col, col2), \
1526 (simple_col, col1, col2), \
1527 (simple_col, col3), \
1528 (simple_col, col1, col3), \
1529 (simple_col, col2, col3), \
1530 (simple_col, col1, col2, col3))]",
1531 &result
1532 );
1533
1534 let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
1536 let result = format!("[{}]", expr_vec_fmt!(sets));
1537 assert_eq!(
1538 "[GROUPING SETS (\
1539 (simple_col), \
1540 (simple_col, col1), \
1541 (simple_col, col1, col2), \
1542 (simple_col, col1, col2, col3))]",
1543 &result
1544 );
1545
1546 let sets =
1548 enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
1549 let result = format!("[{}]", expr_vec_fmt!(sets));
1550 assert_eq!(
1551 "[GROUPING SETS (\
1552 (simple_col, col1, col2, col3))]",
1553 &result
1554 );
1555
1556 let sets = enumerate_grouping_sets(vec![
1558 simple_col.clone(),
1559 grouping_set,
1560 rollup.clone(),
1561 ])?;
1562 let result = format!("[{}]", expr_vec_fmt!(sets));
1563 assert_eq!(
1564 "[GROUPING SETS (\
1565 (simple_col, col1, col2, col3), \
1566 (simple_col, col1, col2, col3, col1), \
1567 (simple_col, col1, col2, col3, col1, col2), \
1568 (simple_col, col1, col2, col3, col1, col2, col3))]",
1569 &result
1570 );
1571
1572 let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
1574 let result = format!("[{}]", expr_vec_fmt!(sets));
1575 assert_eq!(
1576 "[GROUPING SETS (\
1577 (simple_col), \
1578 (simple_col, col1), \
1579 (simple_col, col1, col2), \
1580 (simple_col, col1, col2, col3), \
1581 (simple_col, col1), \
1582 (simple_col, col1, col1), \
1583 (simple_col, col1, col1, col2), \
1584 (simple_col, col1, col1, col2, col3), \
1585 (simple_col, col2), \
1586 (simple_col, col2, col1), \
1587 (simple_col, col2, col1, col2), \
1588 (simple_col, col2, col1, col2, col3), \
1589 (simple_col, col1, col2), \
1590 (simple_col, col1, col2, col1), \
1591 (simple_col, col1, col2, col1, col2), \
1592 (simple_col, col1, col2, col1, col2, col3), \
1593 (simple_col, col3), \
1594 (simple_col, col3, col1), \
1595 (simple_col, col3, col1, col2), \
1596 (simple_col, col3, col1, col2, col3), \
1597 (simple_col, col1, col3), \
1598 (simple_col, col1, col3, col1), \
1599 (simple_col, col1, col3, col1, col2), \
1600 (simple_col, col1, col3, col1, col2, col3), \
1601 (simple_col, col2, col3), \
1602 (simple_col, col2, col3, col1), \
1603 (simple_col, col2, col3, col1, col2), \
1604 (simple_col, col2, col3, col1, col2, col3), \
1605 (simple_col, col1, col2, col3), \
1606 (simple_col, col1, col2, col3, col1), \
1607 (simple_col, col1, col2, col3, col1, col2), \
1608 (simple_col, col1, col2, col3, col1, col2, col3))]",
1609 &result
1610 );
1611
1612 Ok(())
1613 }
1614 #[test]
1615 fn test_split_conjunction() {
1616 let expr = col("a");
1617 let result = split_conjunction(&expr);
1618 assert_eq!(result, vec![&expr]);
1619 }
1620
1621 #[test]
1622 fn test_split_conjunction_two() {
1623 let expr = col("a").eq(lit(5)).and(col("b"));
1624 let expr1 = col("a").eq(lit(5));
1625 let expr2 = col("b");
1626
1627 let result = split_conjunction(&expr);
1628 assert_eq!(result, vec![&expr1, &expr2]);
1629 }
1630
1631 #[test]
1632 fn test_split_conjunction_alias() {
1633 let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
1634 let expr1 = col("a").eq(lit(5));
1635 let expr2 = col("b"); let result = split_conjunction(&expr);
1638 assert_eq!(result, vec![&expr1, &expr2]);
1639 }
1640
1641 #[test]
1642 fn test_split_conjunction_or() {
1643 let expr = col("a").eq(lit(5)).or(col("b"));
1644 let result = split_conjunction(&expr);
1645 assert_eq!(result, vec![&expr]);
1646 }
1647
1648 #[test]
1649 fn test_split_binary_owned() {
1650 let expr = col("a");
1651 assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
1652 }
1653
1654 #[test]
1655 fn test_split_binary_owned_two() {
1656 assert_eq!(
1657 split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
1658 vec![col("a").eq(lit(5)), col("b")]
1659 );
1660 }
1661
1662 #[test]
1663 fn test_split_binary_owned_different_op() {
1664 let expr = col("a").eq(lit(5)).or(col("b"));
1665 assert_eq!(
1666 split_binary_owned(expr.clone(), Operator::And),
1668 vec![expr]
1669 );
1670 }
1671
1672 #[test]
1673 fn test_split_conjunction_owned() {
1674 let expr = col("a");
1675 assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1676 }
1677
1678 #[test]
1679 fn test_split_conjunction_owned_two() {
1680 assert_eq!(
1681 split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
1682 vec![col("a").eq(lit(5)), col("b")]
1683 );
1684 }
1685
1686 #[test]
1687 fn test_split_conjunction_owned_alias() {
1688 assert_eq!(
1689 split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
1690 vec![
1691 col("a").eq(lit(5)),
1692 col("b"),
1694 ]
1695 );
1696 }
1697
1698 #[test]
1699 fn test_conjunction_empty() {
1700 assert_eq!(conjunction(vec![]), None);
1701 }
1702
1703 #[test]
1704 fn test_conjunction() {
1705 let expr = conjunction(vec![col("a"), col("b"), col("c")]);
1707
1708 assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));
1710
1711 assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
1713 }
1714
1715 #[test]
1716 fn test_disjunction_empty() {
1717 assert_eq!(disjunction(vec![]), None);
1718 }
1719
1720 #[test]
1721 fn test_disjunction() {
1722 let expr = disjunction(vec![col("a"), col("b"), col("c")]);
1724
1725 assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));
1727
1728 assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
1730 }
1731
1732 #[test]
1733 fn test_split_conjunction_owned_or() {
1734 let expr = col("a").eq(lit(5)).or(col("b"));
1735 assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1736 }
1737
1738 #[test]
1739 fn test_collect_expr() -> Result<()> {
1740 let mut accum: HashSet<Column> = HashSet::new();
1741 expr_to_columns(
1742 &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1743 &mut accum,
1744 )?;
1745 expr_to_columns(
1746 &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1747 &mut accum,
1748 )?;
1749 assert_eq!(1, accum.len());
1750 assert!(accum.contains(&Column::from_name("a")));
1751 Ok(())
1752 }
1753
1754 #[test]
1755 fn test_can_hash() {
1756 let union_fields: UnionFields = [
1757 (0, Arc::new(Field::new("A", DataType::Int32, true))),
1758 (1, Arc::new(Field::new("B", DataType::Float64, true))),
1759 ]
1760 .into_iter()
1761 .collect();
1762
1763 let union_type = DataType::Union(union_fields, UnionMode::Sparse);
1764 assert!(!can_hash(&union_type));
1765
1766 let list_union_type =
1767 DataType::List(Arc::new(Field::new("my_union", union_type, true)));
1768 assert!(!can_hash(&list_union_type));
1769 }
1770
1771 #[test]
1772 fn test_generate_signature_error_msg_with_parameter_names() {
1773 let sig = Signature::one_of(
1774 vec![
1775 TypeSignature::Exact(vec![DataType::Utf8, DataType::Int64]),
1776 TypeSignature::Exact(vec![
1777 DataType::Utf8,
1778 DataType::Int64,
1779 DataType::Int64,
1780 ]),
1781 ],
1782 Volatility::Immutable,
1783 )
1784 .with_parameter_names(vec![
1785 "str".to_string(),
1786 "start_pos".to_string(),
1787 "length".to_string(),
1788 ])
1789 .expect("valid parameter names");
1790
1791 let error_msg =
1793 generate_signature_error_message("substr", &sig, &[DataType::Utf8]);
1794
1795 assert!(
1796 error_msg.contains("str: Utf8, start_pos: Int64"),
1797 "Expected 'str: Utf8, start_pos: Int64' in error message, got: {error_msg}"
1798 );
1799 assert!(
1800 error_msg.contains("str: Utf8, start_pos: Int64, length: Int64"),
1801 "Expected 'str: Utf8, start_pos: Int64, length: Int64' in error message, got: {error_msg}"
1802 );
1803 }
1804
1805 #[test]
1806 fn test_generate_signature_error_msg_without_parameter_names() {
1807 let sig = Signature::one_of(
1808 vec![TypeSignature::Any(2), TypeSignature::Any(3)],
1809 Volatility::Immutable,
1810 );
1811
1812 let error_msg =
1813 generate_signature_error_message("my_func", &sig, &[DataType::Int32]);
1814
1815 assert!(
1816 error_msg.contains("Any, Any"),
1817 "Expected 'Any, Any' without parameter names, got: {error_msg}"
1818 );
1819 }
1820
1821 #[test]
1822 fn test_signature_error_msg_exact() {
1823 use insta::assert_snapshot;
1824
1825 let sig = Signature::one_of(
1826 vec![
1827 TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]),
1828 TypeSignature::Exact(vec![DataType::Float32, DataType::Int64]),
1829 TypeSignature::Exact(vec![DataType::Float64]),
1830 TypeSignature::Exact(vec![DataType::Float32]),
1831 ],
1832 Volatility::Immutable,
1833 );
1834 let msg = generate_signature_error_message(
1835 "round",
1836 &sig,
1837 &[DataType::Float64, DataType::Float64],
1838 );
1839 assert_snapshot!(msg, @r"
1840 No function matches the given name and argument types 'round(Float64, Float64)'. You might need to add explicit type casts.
1841 Candidate functions:
1842 round(Float64, Int64)
1843 round(Float32, Int64)
1844 round(Float64)
1845 round(Float32)
1846 ");
1847 }
1848
1849 #[test]
1850 fn test_signature_error_msg_coercible() {
1851 use datafusion_common::types::NativeType;
1852 use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
1853 use insta::assert_snapshot;
1854
1855 let sig = Signature::coercible(
1856 vec![
1857 Coercion::new_implicit(
1858 TypeSignatureClass::Native(
1859 datafusion_common::types::logical_float64(),
1860 ),
1861 vec![TypeSignatureClass::Numeric],
1862 NativeType::Float64,
1863 ),
1864 Coercion::new_implicit(
1865 TypeSignatureClass::Native(datafusion_common::types::logical_int64()),
1866 vec![TypeSignatureClass::Integer],
1867 NativeType::Int64,
1868 ),
1869 ],
1870 Volatility::Immutable,
1871 );
1872 let msg = generate_signature_error_message(
1873 "round",
1874 &sig,
1875 &[DataType::Utf8, DataType::Utf8],
1876 );
1877 assert_snapshot!(msg, @r"
1878 No function matches the given name and argument types 'round(Utf8, Utf8)'. You might need to add explicit type casts.
1879 Candidate functions:
1880 round(Float64, Int64)
1881 ");
1882 }
1883
1884 #[test]
1885 fn test_signature_error_msg_with_names_coercible() {
1886 use datafusion_common::types::NativeType;
1887 use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
1888 use insta::assert_snapshot;
1889
1890 let sig = Signature::coercible(
1891 vec![
1892 Coercion::new_exact(TypeSignatureClass::Native(
1893 datafusion_common::types::logical_string(),
1894 )),
1895 Coercion::new_exact(TypeSignatureClass::Native(
1896 datafusion_common::types::logical_int64(),
1897 )),
1898 Coercion::new_implicit(
1899 TypeSignatureClass::Native(datafusion_common::types::logical_int64()),
1900 vec![TypeSignatureClass::Integer],
1901 NativeType::Int64,
1902 ),
1903 ],
1904 Volatility::Immutable,
1905 )
1906 .with_parameter_names(vec![
1907 "string".to_string(),
1908 "start_pos".to_string(),
1909 "length".to_string(),
1910 ])
1911 .expect("valid parameter names");
1912
1913 let msg = generate_signature_error_message("substr", &sig, &[DataType::Int32]);
1914 assert_snapshot!(msg, @r"
1915 No function matches the given name and argument types 'substr(Int32)'. You might need to add explicit type casts.
1916 Candidate functions:
1917 substr(string: String, start_pos: Int64, length: Int64)
1918 ");
1919 }
1920}