1use std::cmp::Ordering;
21use std::collections::{BTreeSet, HashSet};
22use std::sync::Arc;
23
24use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
25use crate::expr_rewriter::strip_outer_reference;
26use crate::{
27 and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
28};
29use datafusion_expr_common::signature::{Signature, TypeSignature};
30
31use arrow::datatypes::{DataType, Field, Schema};
32use datafusion_common::tree_node::{
33 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34};
35use datafusion_common::utils::get_at_indices;
36use datafusion_common::{
37 internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap,
38 Result, TableReference,
39};
40
41use indexmap::IndexSet;
42use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
43
44pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
45
46pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
49
50pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
53 if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
54 if group_expr.len() > 1 {
55 return plan_err!(
56 "Invalid group by expressions, GroupingSet must be the only expression"
57 );
58 }
59 Ok(grouping_set.distinct_expr().len() + 1)
61 } else {
62 grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
63 }
64}
65
/// Builds every subset of `slice`, as vectors of references into it.
///
/// Subsets are ordered by the numeric value of their membership bitmask
/// (bit `i` set means `slice[i]` is included), and the elements inside each
/// subset keep the order they have in `slice`.
///
/// # Errors
/// Returns an error message when `slice` has 64 or more elements, because the
/// bitmask is a `u64`.
fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
    if slice.len() >= 64 {
        return Err("The size of the set must be less than 64.".into());
    }

    let subset_count: u64 = 1 << slice.len();
    let subsets = (0..subset_count)
        .map(|mask| {
            slice
                .iter()
                .enumerate()
                // keep the elements whose bit is set in this mask
                .filter(|&(idx, _)| mask & (1u64 << idx) != 0)
                .map(|(_, item)| item)
                .collect()
        })
        .collect();
    Ok(subsets)
}
104
105fn check_grouping_set_size_limit(size: usize) -> Result<()> {
107 let max_grouping_set_size = 65535;
108 if size > max_grouping_set_size {
109 return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}");
110 }
111
112 Ok(())
113}
114
115fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
117 let max_grouping_sets_size = 4096;
118 if size > max_grouping_sets_size {
119 return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}");
120 }
121
122 Ok(())
123}
124
125fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
137 check_grouping_set_size_limit(left.len() + right.len())?;
138 Ok(left.iter().chain(right.iter()).cloned().collect())
139}
140
141fn cross_join_grouping_sets<T: Clone>(
154 left: &[Vec<T>],
155 right: &[Vec<T>],
156) -> Result<Vec<Vec<T>>> {
157 let grouping_sets_size = left.len() * right.len();
158
159 check_grouping_sets_size_limit(grouping_sets_size)?;
160
161 let mut result = Vec::with_capacity(grouping_sets_size);
162 for le in left {
163 for re in right {
164 result.push(merge_grouping_set(le, re)?);
165 }
166 }
167 Ok(result)
168}
169
170pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
191 let has_grouping_set = group_expr
192 .iter()
193 .any(|expr| matches!(expr, Expr::GroupingSet(_)));
194 if !has_grouping_set || group_expr.len() == 1 {
195 return Ok(group_expr);
196 }
197 let partial_sets = group_expr
199 .iter()
200 .map(|expr| {
201 let exprs = match expr {
202 Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
203 check_grouping_sets_size_limit(grouping_sets.len())?;
204 grouping_sets.iter().map(|e| e.iter().collect()).collect()
205 }
206 Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
207 let grouping_sets = powerset(group_exprs)
208 .map_err(|e| plan_datafusion_err!("{}", e))?;
209 check_grouping_sets_size_limit(grouping_sets.len())?;
210 grouping_sets
211 }
212 Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
213 let size = group_exprs.len();
214 let slice = group_exprs.as_slice();
215 check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
216 (0..(size + 1))
217 .map(|i| slice[0..i].iter().collect())
218 .collect()
219 }
220 expr => vec![vec![expr]],
221 };
222 Ok(exprs)
223 })
224 .collect::<Result<Vec<_>>>()?;
225
226 let grouping_sets = partial_sets
228 .into_iter()
229 .map(Ok)
230 .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
231 .transpose()?
232 .map(|e| {
233 e.into_iter()
234 .map(|e| e.into_iter().cloned().collect())
235 .collect()
236 })
237 .unwrap_or_default();
238
239 Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
240 grouping_sets,
241 ))])
242}
243
244pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
247 if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
248 if group_expr.len() > 1 {
249 return plan_err!(
250 "Invalid group by expressions, GroupingSet must be the only expression"
251 );
252 }
253 Ok(grouping_set.distinct_expr())
254 } else {
255 Ok(group_expr
256 .iter()
257 .collect::<IndexSet<_>>()
258 .into_iter()
259 .collect())
260 }
261}
262
263pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
266 expr.apply(|expr| {
267 match expr {
268 Expr::Column(qc) => {
269 accum.insert(qc.clone());
270 }
271 #[expect(deprecated)]
276 Expr::Unnest(_)
277 | Expr::ScalarVariable(_, _)
278 | Expr::Alias(_)
279 | Expr::Literal(_, _)
280 | Expr::BinaryExpr { .. }
281 | Expr::Like { .. }
282 | Expr::SimilarTo { .. }
283 | Expr::Not(_)
284 | Expr::IsNotNull(_)
285 | Expr::IsNull(_)
286 | Expr::IsTrue(_)
287 | Expr::IsFalse(_)
288 | Expr::IsUnknown(_)
289 | Expr::IsNotTrue(_)
290 | Expr::IsNotFalse(_)
291 | Expr::IsNotUnknown(_)
292 | Expr::Negative(_)
293 | Expr::Between { .. }
294 | Expr::Case { .. }
295 | Expr::Cast { .. }
296 | Expr::TryCast { .. }
297 | Expr::ScalarFunction(..)
298 | Expr::WindowFunction { .. }
299 | Expr::AggregateFunction { .. }
300 | Expr::GroupingSet(_)
301 | Expr::InList { .. }
302 | Expr::Exists { .. }
303 | Expr::InSubquery(_)
304 | Expr::ScalarSubquery(_)
305 | Expr::Wildcard { .. }
306 | Expr::Placeholder(_)
307 | Expr::OuterReferenceColumn { .. } => {}
308 }
309 Ok(TreeNodeRecursion::Continue)
310 })
311 .map(|_| ())
312}
313
314fn get_excluded_columns(
317 opt_exclude: Option<&ExcludeSelectItem>,
318 opt_except: Option<&ExceptSelectItem>,
319 schema: &DFSchema,
320 qualifier: Option<&TableReference>,
321) -> Result<Vec<Column>> {
322 let mut idents = vec![];
323 if let Some(excepts) = opt_except {
324 idents.push(&excepts.first_element);
325 idents.extend(&excepts.additional_elements);
326 }
327 if let Some(exclude) = opt_exclude {
328 match exclude {
329 ExcludeSelectItem::Single(ident) => idents.push(ident),
330 ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
331 }
332 }
333 let n_elem = idents.len();
335 let unique_idents = idents.into_iter().collect::<HashSet<_>>();
336 if n_elem != unique_idents.len() {
339 return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
340 }
341
342 let mut result = vec![];
343 for ident in unique_idents.into_iter() {
344 let col_name = ident.value.as_str();
345 let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
346 result.push(Column::from((qualifier, field)));
347 }
348 Ok(result)
349}
350
351fn get_exprs_except_skipped(
353 schema: &DFSchema,
354 columns_to_skip: HashSet<Column>,
355) -> Vec<Expr> {
356 if columns_to_skip.is_empty() {
357 schema.iter().map(Expr::from).collect::<Vec<Expr>>()
358 } else {
359 schema
360 .columns()
361 .iter()
362 .filter_map(|c| {
363 if !columns_to_skip.contains(c) {
364 Some(Expr::Column(c.clone()))
365 } else {
366 None
367 }
368 })
369 .collect::<Vec<Expr>>()
370 }
371}
372
373fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
377 let using_columns = plan.using_columns()?;
378 let excluded = using_columns
379 .into_iter()
380 .flat_map(|cols| {
382 let mut cols = cols.into_iter().collect::<Vec<_>>();
383 cols.sort();
386 let mut out_column_names: HashSet<String> = HashSet::new();
387 cols.into_iter().filter_map(move |c| {
388 if out_column_names.contains(&c.name) {
389 Some(c)
390 } else {
391 out_column_names.insert(c.name);
392 None
393 }
394 })
395 })
396 .collect::<HashSet<_>>();
397 Ok(excluded)
398}
399
400pub fn expand_wildcard(
402 schema: &DFSchema,
403 plan: &LogicalPlan,
404 wildcard_options: Option<&WildcardOptions>,
405) -> Result<Vec<Expr>> {
406 let mut columns_to_skip = exclude_using_columns(plan)?;
407 let excluded_columns = if let Some(WildcardOptions {
408 exclude: opt_exclude,
409 except: opt_except,
410 ..
411 }) = wildcard_options
412 {
413 get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
414 } else {
415 vec![]
416 };
417 columns_to_skip.extend(excluded_columns);
419 Ok(get_exprs_except_skipped(schema, columns_to_skip))
420}
421
422pub fn expand_qualified_wildcard(
424 qualifier: &TableReference,
425 schema: &DFSchema,
426 wildcard_options: Option<&WildcardOptions>,
427) -> Result<Vec<Expr>> {
428 let qualified_indices = schema.fields_indices_with_qualified(qualifier);
429 let projected_func_dependencies = schema
430 .functional_dependencies()
431 .project_functional_dependencies(&qualified_indices, qualified_indices.len());
432 let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
433 if fields_with_qualified.is_empty() {
434 return plan_err!("Invalid qualifier {qualifier}");
435 }
436
437 let qualified_schema = Arc::new(Schema::new_with_metadata(
438 fields_with_qualified,
439 schema.metadata().clone(),
440 ));
441 let qualified_dfschema =
442 DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
443 .with_functional_dependencies(projected_func_dependencies)?;
444 let excluded_columns = if let Some(WildcardOptions {
445 exclude: opt_exclude,
446 except: opt_except,
447 ..
448 }) = wildcard_options
449 {
450 get_excluded_columns(
451 opt_exclude.as_ref(),
452 opt_except.as_ref(),
453 schema,
454 Some(qualifier),
455 )?
456 } else {
457 vec![]
458 };
459 let mut columns_to_skip = HashSet::new();
461 columns_to_skip.extend(excluded_columns);
462 Ok(get_exprs_except_skipped(
463 &qualified_dfschema,
464 columns_to_skip,
465 ))
466}
467
/// Sort key of a window expression: each [`Sort`] is paired with a flag that
/// [`generate_sort_key`] sets to `true` for keys derived from `partition_by`
/// and to `false` for keys that come only from `order_by`.
type WindowSortKey = Vec<(Sort, bool)>;
471
472pub fn generate_sort_key(
474 partition_by: &[Expr],
475 order_by: &[Sort],
476) -> Result<WindowSortKey> {
477 let normalized_order_by_keys = order_by
478 .iter()
479 .map(|e| {
480 let Sort { expr, .. } = e;
481 Sort::new(expr.clone(), true, false)
482 })
483 .collect::<Vec<_>>();
484
485 let mut final_sort_keys = vec![];
486 let mut is_partition_flag = vec![];
487 partition_by.iter().for_each(|e| {
488 let e = e.clone().sort(true, false);
491 if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
492 let order_by_key = &order_by[pos];
493 if !final_sort_keys.contains(order_by_key) {
494 final_sort_keys.push(order_by_key.clone());
495 is_partition_flag.push(true);
496 }
497 } else if !final_sort_keys.contains(&e) {
498 final_sort_keys.push(e);
499 is_partition_flag.push(true);
500 }
501 });
502
503 order_by.iter().for_each(|e| {
504 if !final_sort_keys.contains(e) {
505 final_sort_keys.push(e.clone());
506 is_partition_flag.push(false);
507 }
508 });
509 let res = final_sort_keys
510 .into_iter()
511 .zip(is_partition_flag)
512 .collect::<Vec<_>>();
513 Ok(res)
514}
515
516pub fn compare_sort_expr(
519 sort_expr_a: &Sort,
520 sort_expr_b: &Sort,
521 schema: &DFSchemaRef,
522) -> Ordering {
523 let Sort {
524 expr: expr_a,
525 asc: asc_a,
526 nulls_first: nulls_first_a,
527 } = sort_expr_a;
528
529 let Sort {
530 expr: expr_b,
531 asc: asc_b,
532 nulls_first: nulls_first_b,
533 } = sort_expr_b;
534
535 let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
536 let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
537 for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
538 match idx_a.cmp(idx_b) {
539 Ordering::Less => {
540 return Ordering::Less;
541 }
542 Ordering::Greater => {
543 return Ordering::Greater;
544 }
545 Ordering::Equal => {}
546 }
547 }
548 match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
549 Ordering::Less => return Ordering::Greater,
550 Ordering::Greater => {
551 return Ordering::Less;
552 }
553 Ordering::Equal => {}
554 }
555 match (asc_a, asc_b) {
556 (true, false) => {
557 return Ordering::Greater;
558 }
559 (false, true) => {
560 return Ordering::Less;
561 }
562 _ => {}
563 }
564 match (nulls_first_a, nulls_first_b) {
565 (true, false) => {
566 return Ordering::Less;
567 }
568 (false, true) => {
569 return Ordering::Greater;
570 }
571 _ => {}
572 }
573 Ordering::Equal
574}
575
576pub fn group_window_expr_by_sort_keys(
578 window_expr: impl IntoIterator<Item = Expr>,
579) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
580 let mut result = vec![];
581 window_expr.into_iter().try_for_each(|expr| match &expr {
582 Expr::WindowFunction(window_fun) => {
583 let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params;
584 let sort_key = generate_sort_key(partition_by, order_by)?;
585 if let Some((_, values)) = result.iter_mut().find(
586 |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
587 ) {
588 values.push(expr);
589 } else {
590 result.push((sort_key, vec![expr]))
591 }
592 Ok(())
593 }
594 other => internal_err!(
595 "Impossibly got non-window expr {other:?}"
596 ),
597 })?;
598 Ok(result)
599}
600
601pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
605 find_exprs_in_exprs(exprs, &|nested_expr| {
606 matches!(nested_expr, Expr::AggregateFunction { .. })
607 })
608}
609
610pub fn find_window_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
613 find_exprs_in_exprs(exprs, &|nested_expr| {
614 matches!(nested_expr, Expr::WindowFunction { .. })
615 })
616}
617
618pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
621 find_exprs_in_expr(expr, &|nested_expr| {
622 matches!(nested_expr, Expr::OuterReferenceColumn { .. })
623 })
624}
625
626fn find_exprs_in_exprs<'a, F>(
630 exprs: impl IntoIterator<Item = &'a Expr>,
631 test_fn: &F,
632) -> Vec<Expr>
633where
634 F: Fn(&Expr) -> bool,
635{
636 exprs
637 .into_iter()
638 .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
639 .fold(vec![], |mut acc, expr| {
640 if !acc.contains(&expr) {
641 acc.push(expr)
642 }
643 acc
644 })
645}
646
647fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
651where
652 F: Fn(&Expr) -> bool,
653{
654 let mut exprs = vec![];
655 expr.apply(|expr| {
656 if test_fn(expr) {
657 if !(exprs.contains(expr)) {
658 exprs.push(expr.clone())
659 }
660 return Ok(TreeNodeRecursion::Jump);
662 }
663
664 Ok(TreeNodeRecursion::Continue)
665 })
666 .expect("no way to return error during recursion");
668 exprs
669}
670
671pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
673where
674 F: FnMut(&Expr) -> Result<(), E>,
675{
676 let mut err = Ok(());
677 expr.apply(|expr| {
678 if let Err(e) = f(expr) {
679 err = Err(e);
681 Ok(TreeNodeRecursion::Stop)
682 } else {
683 Ok(TreeNodeRecursion::Continue)
685 }
686 })
687 .expect("no way to return error during recursion");
689
690 err
691}
692
693pub fn exprlist_to_fields<'a>(
695 exprs: impl IntoIterator<Item = &'a Expr>,
696 plan: &LogicalPlan,
697) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
698 let input_schema = plan.schema();
700 exprs
701 .into_iter()
702 .map(|e| e.to_field(input_schema))
703 .collect()
704}
705
706pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
722 let output_exprs = match input.columnized_output_exprs() {
723 Ok(exprs) if !exprs.is_empty() => exprs,
724 _ => return Ok(e),
725 };
726 let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
727 e.transform_down(|node: Expr| match exprs_map.get(&node) {
728 Some(column) => Ok(Transformed::new(
729 Expr::Column(column.clone()),
730 true,
731 TreeNodeRecursion::Jump,
732 )),
733 None => Ok(Transformed::no(node)),
734 })
735 .data()
736}
737
738pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
741 exprs
742 .iter()
743 .flat_map(find_columns_referenced_by_expr)
744 .map(Expr::Column)
745 .collect()
746}
747
748pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
749 let mut exprs = vec![];
750 e.apply(|expr| {
751 if let Expr::Column(c) = expr {
752 exprs.push(c.clone())
753 }
754 Ok(TreeNodeRecursion::Continue)
755 })
756 .expect("Unexpected error");
758 exprs
759}
760
761pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
763 match expr {
764 Expr::Column(col) => {
765 let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
766 Ok(Expr::from(Column::from((qualifier, field))))
767 }
768 _ => Ok(Expr::Column(Column::from_name(
769 expr.schema_name().to_string(),
770 ))),
771 }
772}
773
774pub(crate) fn find_column_indexes_referenced_by_expr(
777 e: &Expr,
778 schema: &DFSchemaRef,
779) -> Vec<usize> {
780 let mut indexes = vec![];
781 e.apply(|expr| {
782 match expr {
783 Expr::Column(qc) => {
784 if let Ok(idx) = schema.index_of_column(qc) {
785 indexes.push(idx);
786 }
787 }
788 Expr::Literal(_, _) => {
789 indexes.push(usize::MAX);
790 }
791 _ => {}
792 }
793 Ok(TreeNodeRecursion::Continue)
794 })
795 .unwrap();
796 indexes
797}
798
799pub fn can_hash(data_type: &DataType) -> bool {
803 match data_type {
804 DataType::Null => true,
805 DataType::Boolean => true,
806 DataType::Int8 => true,
807 DataType::Int16 => true,
808 DataType::Int32 => true,
809 DataType::Int64 => true,
810 DataType::UInt8 => true,
811 DataType::UInt16 => true,
812 DataType::UInt32 => true,
813 DataType::UInt64 => true,
814 DataType::Float16 => true,
815 DataType::Float32 => true,
816 DataType::Float64 => true,
817 DataType::Decimal32(_, _) => true,
818 DataType::Decimal64(_, _) => true,
819 DataType::Decimal128(_, _) => true,
820 DataType::Decimal256(_, _) => true,
821 DataType::Timestamp(_, _) => true,
822 DataType::Utf8 => true,
823 DataType::LargeUtf8 => true,
824 DataType::Utf8View => true,
825 DataType::Binary => true,
826 DataType::LargeBinary => true,
827 DataType::BinaryView => true,
828 DataType::Date32 => true,
829 DataType::Date64 => true,
830 DataType::Time32(_) => true,
831 DataType::Time64(_) => true,
832 DataType::Duration(_) => true,
833 DataType::Interval(_) => true,
834 DataType::FixedSizeBinary(_) => true,
835 DataType::Dictionary(key_type, value_type) => {
836 DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
837 }
838 DataType::List(value_type) => can_hash(value_type.data_type()),
839 DataType::LargeList(value_type) => can_hash(value_type.data_type()),
840 DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
841 DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
842 DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
843
844 DataType::ListView(_)
845 | DataType::LargeListView(_)
846 | DataType::Union(_, _)
847 | DataType::RunEndEncoded(_, _) => false,
848 }
849}
850
851pub fn check_all_columns_from_schema(
853 columns: &HashSet<&Column>,
854 schema: &DFSchema,
855) -> Result<bool> {
856 for col in columns.iter() {
857 let exist = schema.is_column_from_schema(col);
858 if !exist {
859 return Ok(false);
860 }
861 }
862
863 Ok(true)
864}
865
866pub fn find_valid_equijoin_key_pair(
876 left_key: &Expr,
877 right_key: &Expr,
878 left_schema: &DFSchema,
879 right_schema: &DFSchema,
880) -> Result<Option<(Expr, Expr)>> {
881 let left_using_columns = left_key.column_refs();
882 let right_using_columns = right_key.column_refs();
883
884 if left_using_columns.is_empty() || right_using_columns.is_empty() {
886 return Ok(None);
887 }
888
889 if check_all_columns_from_schema(&left_using_columns, left_schema)?
890 && check_all_columns_from_schema(&right_using_columns, right_schema)?
891 {
892 return Ok(Some((left_key.clone(), right_key.clone())));
893 } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
894 && check_all_columns_from_schema(&left_using_columns, right_schema)?
895 {
896 return Ok(Some((right_key.clone(), left_key.clone())));
897 }
898
899 Ok(None)
900}
901
902pub fn generate_signature_error_msg(
914 func_name: &str,
915 func_signature: Signature,
916 input_expr_types: &[DataType],
917) -> String {
918 let candidate_signatures = func_signature
919 .type_signature
920 .to_string_repr()
921 .iter()
922 .map(|args_str| format!("\t{func_name}({args_str})"))
923 .collect::<Vec<String>>()
924 .join("\n");
925
926 format!(
927 "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
928 func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures
929 )
930}
931
932pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
936 split_conjunction_impl(expr, vec![])
937}
938
939fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
940 match expr {
941 Expr::BinaryExpr(BinaryExpr {
942 right,
943 op: Operator::And,
944 left,
945 }) => {
946 let exprs = split_conjunction_impl(left, exprs);
947 split_conjunction_impl(right, exprs)
948 }
949 Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
950 other => {
951 exprs.push(other);
952 exprs
953 }
954 }
955}
956
957pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
961 let mut stack = vec![expr];
962 std::iter::from_fn(move || {
963 while let Some(expr) = stack.pop() {
964 match expr {
965 Expr::BinaryExpr(BinaryExpr {
966 right,
967 op: Operator::And,
968 left,
969 }) => {
970 stack.push(right);
971 stack.push(left);
972 }
973 Expr::Alias(Alias { expr, .. }) => stack.push(expr),
974 other => return Some(other),
975 }
976 }
977 None
978 })
979}
980
981pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
985 let mut stack = vec![expr];
986 std::iter::from_fn(move || {
987 while let Some(expr) = stack.pop() {
988 match expr {
989 Expr::BinaryExpr(BinaryExpr {
990 right,
991 op: Operator::And,
992 left,
993 }) => {
994 stack.push(*right);
995 stack.push(*left);
996 }
997 Expr::Alias(Alias { expr, .. }) => stack.push(*expr),
998 other => return Some(other),
999 }
1000 }
1001 None
1002 })
1003}
1004
1005pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1027 split_binary_owned(expr, Operator::And)
1028}
1029
1030pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1053 split_binary_owned_impl(expr, op, vec![])
1054}
1055
1056fn split_binary_owned_impl(
1057 expr: Expr,
1058 operator: Operator,
1059 mut exprs: Vec<Expr>,
1060) -> Vec<Expr> {
1061 match expr {
1062 Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1063 let exprs = split_binary_owned_impl(*left, operator, exprs);
1064 split_binary_owned_impl(*right, operator, exprs)
1065 }
1066 Expr::Alias(Alias { expr, .. }) => {
1067 split_binary_owned_impl(*expr, operator, exprs)
1068 }
1069 other => {
1070 exprs.push(other);
1071 exprs
1072 }
1073 }
1074}
1075
1076pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1080 split_binary_impl(expr, op, vec![])
1081}
1082
1083fn split_binary_impl<'a>(
1084 expr: &'a Expr,
1085 operator: Operator,
1086 mut exprs: Vec<&'a Expr>,
1087) -> Vec<&'a Expr> {
1088 match expr {
1089 Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1090 let exprs = split_binary_impl(left, operator, exprs);
1091 split_binary_impl(right, operator, exprs)
1092 }
1093 Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1094 other => {
1095 exprs.push(other);
1096 exprs
1097 }
1098 }
1099}
1100
1101pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1124 filters.into_iter().reduce(Expr::and)
1125}
1126
1127pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1150 filters.into_iter().reduce(Expr::or)
1151}
1152
1153pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1168 let predicate = predicates
1170 .iter()
1171 .skip(1)
1172 .fold(predicates[0].clone(), |acc, predicate| {
1173 and(acc, (*predicate).to_owned())
1174 });
1175
1176 Ok(LogicalPlan::Filter(Filter::try_new(
1177 predicate,
1178 Arc::new(plan),
1179 )?))
1180}
1181
1182pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1193 let mut joins = vec![];
1194 let mut others = vec![];
1195 for filter in exprs.into_iter() {
1196 if filter.contains_outer() {
1198 if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1199 {
1200 joins.push(strip_outer_reference((*filter).clone()));
1201 }
1202 } else {
1203 others.push((*filter).clone());
1204 }
1205 }
1206
1207 Ok((joins, others))
1208}
1209
1210pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1220 match slice {
1221 [it] => Ok(it),
1222 [] => plan_err!("No items found!"),
1223 _ => plan_err!("More than one item found!"),
1224 }
1225}
1226
1227pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1232 if inputs.len() == 1 {
1233 inputs[0].schema().as_ref().clone()
1234 } else {
1235 inputs.iter().map(|input| input.schema()).fold(
1236 DFSchema::empty(),
1237 |mut lhs, rhs| {
1238 lhs.merge(rhs);
1239 lhs
1240 },
1241 )
1242 }
1243}
1244
/// Builds the name of a state field as `name[state_name]`.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    [name, "[", state_name, "]"].concat()
}
1249
1250pub fn collect_subquery_cols(
1252 exprs: &[Expr],
1253 subquery_schema: &DFSchema,
1254) -> Result<BTreeSet<Column>> {
1255 exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
1256 let mut using_cols: Vec<Column> = vec![];
1257 for col in expr.column_refs().into_iter() {
1258 if subquery_schema.has_column(col) {
1259 using_cols.push(col.clone());
1260 }
1261 }
1262
1263 cols.extend(using_cols);
1264 Result::<_>::Ok(cols)
1265 })
1266}
1267
1268#[cfg(test)]
1269mod tests {
1270 use super::*;
1271 use crate::{
1272 col, cube,
1273 expr::WindowFunction,
1274 expr_vec_fmt, grouping_set, lit, rollup,
1275 test::function_stub::{max_udaf, min_udaf, sum_udaf},
1276 Cast, ExprFunctionExt, WindowFunctionDefinition,
1277 };
1278 use arrow::datatypes::{UnionFields, UnionMode};
1279
1280 #[test]
1281 fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
1282 let result = group_window_expr_by_sort_keys(vec![])?;
1283 let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
1284 assert_eq!(expected, result);
1285 Ok(())
1286 }
1287
1288 #[test]
1289 fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
1290 let max1 = Expr::from(WindowFunction::new(
1291 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1292 vec![col("name")],
1293 ));
1294 let max2 = Expr::from(WindowFunction::new(
1295 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1296 vec![col("name")],
1297 ));
1298 let min3 = Expr::from(WindowFunction::new(
1299 WindowFunctionDefinition::AggregateUDF(min_udaf()),
1300 vec![col("name")],
1301 ));
1302 let sum4 = Expr::from(WindowFunction::new(
1303 WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1304 vec![col("age")],
1305 ));
1306 let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1307 let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1308 let key = vec![];
1309 let expected: Vec<(WindowSortKey, Vec<Expr>)> =
1310 vec![(key, vec![max1, max2, min3, sum4])];
1311 assert_eq!(expected, result);
1312 Ok(())
1313 }
1314
1315 #[test]
1316 fn test_group_window_expr_by_sort_keys() -> Result<()> {
1317 let age_asc = Sort::new(col("age"), true, true);
1318 let name_desc = Sort::new(col("name"), false, true);
1319 let created_at_desc = Sort::new(col("created_at"), false, true);
1320 let max1 = Expr::from(WindowFunction::new(
1321 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1322 vec![col("name")],
1323 ))
1324 .order_by(vec![age_asc.clone(), name_desc.clone()])
1325 .build()
1326 .unwrap();
1327 let max2 = Expr::from(WindowFunction::new(
1328 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1329 vec![col("name")],
1330 ));
1331 let min3 = Expr::from(WindowFunction::new(
1332 WindowFunctionDefinition::AggregateUDF(min_udaf()),
1333 vec![col("name")],
1334 ))
1335 .order_by(vec![age_asc.clone(), name_desc.clone()])
1336 .build()
1337 .unwrap();
1338 let sum4 = Expr::from(WindowFunction::new(
1339 WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1340 vec![col("age")],
1341 ))
1342 .order_by(vec![
1343 name_desc.clone(),
1344 age_asc.clone(),
1345 created_at_desc.clone(),
1346 ])
1347 .build()
1348 .unwrap();
1349 let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1351 let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1352
1353 let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
1354 let key2 = vec![];
1355 let key3 = vec![
1356 (name_desc, false),
1357 (age_asc, false),
1358 (created_at_desc, false),
1359 ];
1360
1361 let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
1362 (key1, vec![max1, min3]),
1363 (key2, vec![max2]),
1364 (key3, vec![sum4]),
1365 ];
1366 assert_eq!(expected, result);
1367 Ok(())
1368 }
1369
1370 #[test]
1371 fn avoid_generate_duplicate_sort_keys() -> Result<()> {
1372 let asc_or_desc = [true, false];
1373 let nulls_first_or_last = [true, false];
1374 let partition_by = &[col("age"), col("name"), col("created_at")];
1375 for asc_ in asc_or_desc {
1376 for nulls_first_ in nulls_first_or_last {
1377 let order_by = &[
1378 Sort {
1379 expr: col("age"),
1380 asc: asc_,
1381 nulls_first: nulls_first_,
1382 },
1383 Sort {
1384 expr: col("name"),
1385 asc: asc_,
1386 nulls_first: nulls_first_,
1387 },
1388 ];
1389
1390 let expected = vec![
1391 (
1392 Sort {
1393 expr: col("age"),
1394 asc: asc_,
1395 nulls_first: nulls_first_,
1396 },
1397 true,
1398 ),
1399 (
1400 Sort {
1401 expr: col("name"),
1402 asc: asc_,
1403 nulls_first: nulls_first_,
1404 },
1405 true,
1406 ),
1407 (
1408 Sort {
1409 expr: col("created_at"),
1410 asc: true,
1411 nulls_first: false,
1412 },
1413 true,
1414 ),
1415 ];
1416 let result = generate_sort_key(partition_by, order_by)?;
1417 assert_eq!(expected, result);
1418 }
1419 }
1420 Ok(())
1421 }
1422
1423 #[test]
1424 fn test_enumerate_grouping_sets() -> Result<()> {
1425 let multi_cols = vec![col("col1"), col("col2"), col("col3")];
1426 let simple_col = col("simple_col");
1427 let cube = cube(multi_cols.clone());
1428 let rollup = rollup(multi_cols.clone());
1429 let grouping_set = grouping_set(vec![multi_cols]);
1430
1431 let sets = enumerate_grouping_sets(vec![simple_col.clone()])?;
1433 let result = format!("[{}]", expr_vec_fmt!(sets));
1434 assert_eq!("[simple_col]", &result);
1435
1436 let sets = enumerate_grouping_sets(vec![cube.clone()])?;
1438 let result = format!("[{}]", expr_vec_fmt!(sets));
1439 assert_eq!("[CUBE (col1, col2, col3)]", &result);
1440
1441 let sets = enumerate_grouping_sets(vec![rollup.clone()])?;
1443 let result = format!("[{}]", expr_vec_fmt!(sets));
1444 assert_eq!("[ROLLUP (col1, col2, col3)]", &result);
1445
1446 let sets = enumerate_grouping_sets(vec![simple_col.clone(), cube.clone()])?;
1448 let result = format!("[{}]", expr_vec_fmt!(sets));
1449 assert_eq!(
1450 "[GROUPING SETS (\
1451 (simple_col), \
1452 (simple_col, col1), \
1453 (simple_col, col2), \
1454 (simple_col, col1, col2), \
1455 (simple_col, col3), \
1456 (simple_col, col1, col3), \
1457 (simple_col, col2, col3), \
1458 (simple_col, col1, col2, col3))]",
1459 &result
1460 );
1461
1462 let sets = enumerate_grouping_sets(vec![simple_col.clone(), rollup.clone()])?;
1464 let result = format!("[{}]", expr_vec_fmt!(sets));
1465 assert_eq!(
1466 "[GROUPING SETS (\
1467 (simple_col), \
1468 (simple_col, col1), \
1469 (simple_col, col1, col2), \
1470 (simple_col, col1, col2, col3))]",
1471 &result
1472 );
1473
1474 let sets =
1476 enumerate_grouping_sets(vec![simple_col.clone(), grouping_set.clone()])?;
1477 let result = format!("[{}]", expr_vec_fmt!(sets));
1478 assert_eq!(
1479 "[GROUPING SETS (\
1480 (simple_col, col1, col2, col3))]",
1481 &result
1482 );
1483
1484 let sets = enumerate_grouping_sets(vec![
1486 simple_col.clone(),
1487 grouping_set,
1488 rollup.clone(),
1489 ])?;
1490 let result = format!("[{}]", expr_vec_fmt!(sets));
1491 assert_eq!(
1492 "[GROUPING SETS (\
1493 (simple_col, col1, col2, col3), \
1494 (simple_col, col1, col2, col3, col1), \
1495 (simple_col, col1, col2, col3, col1, col2), \
1496 (simple_col, col1, col2, col3, col1, col2, col3))]",
1497 &result
1498 );
1499
1500 let sets = enumerate_grouping_sets(vec![simple_col, cube, rollup])?;
1502 let result = format!("[{}]", expr_vec_fmt!(sets));
1503 assert_eq!(
1504 "[GROUPING SETS (\
1505 (simple_col), \
1506 (simple_col, col1), \
1507 (simple_col, col1, col2), \
1508 (simple_col, col1, col2, col3), \
1509 (simple_col, col1), \
1510 (simple_col, col1, col1), \
1511 (simple_col, col1, col1, col2), \
1512 (simple_col, col1, col1, col2, col3), \
1513 (simple_col, col2), \
1514 (simple_col, col2, col1), \
1515 (simple_col, col2, col1, col2), \
1516 (simple_col, col2, col1, col2, col3), \
1517 (simple_col, col1, col2), \
1518 (simple_col, col1, col2, col1), \
1519 (simple_col, col1, col2, col1, col2), \
1520 (simple_col, col1, col2, col1, col2, col3), \
1521 (simple_col, col3), \
1522 (simple_col, col3, col1), \
1523 (simple_col, col3, col1, col2), \
1524 (simple_col, col3, col1, col2, col3), \
1525 (simple_col, col1, col3), \
1526 (simple_col, col1, col3, col1), \
1527 (simple_col, col1, col3, col1, col2), \
1528 (simple_col, col1, col3, col1, col2, col3), \
1529 (simple_col, col2, col3), \
1530 (simple_col, col2, col3, col1), \
1531 (simple_col, col2, col3, col1, col2), \
1532 (simple_col, col2, col3, col1, col2, col3), \
1533 (simple_col, col1, col2, col3), \
1534 (simple_col, col1, col2, col3, col1), \
1535 (simple_col, col1, col2, col3, col1, col2), \
1536 (simple_col, col1, col2, col3, col1, col2, col3))]",
1537 &result
1538 );
1539
1540 Ok(())
1541 }
1542 #[test]
1543 fn test_split_conjunction() {
1544 let expr = col("a");
1545 let result = split_conjunction(&expr);
1546 assert_eq!(result, vec![&expr]);
1547 }
1548
1549 #[test]
1550 fn test_split_conjunction_two() {
1551 let expr = col("a").eq(lit(5)).and(col("b"));
1552 let expr1 = col("a").eq(lit(5));
1553 let expr2 = col("b");
1554
1555 let result = split_conjunction(&expr);
1556 assert_eq!(result, vec![&expr1, &expr2]);
1557 }
1558
1559 #[test]
1560 fn test_split_conjunction_alias() {
1561 let expr = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
1562 let expr1 = col("a").eq(lit(5));
1563 let expr2 = col("b"); let result = split_conjunction(&expr);
1566 assert_eq!(result, vec![&expr1, &expr2]);
1567 }
1568
1569 #[test]
1570 fn test_split_conjunction_or() {
1571 let expr = col("a").eq(lit(5)).or(col("b"));
1572 let result = split_conjunction(&expr);
1573 assert_eq!(result, vec![&expr]);
1574 }
1575
1576 #[test]
1577 fn test_split_binary_owned() {
1578 let expr = col("a");
1579 assert_eq!(split_binary_owned(expr.clone(), Operator::And), vec![expr]);
1580 }
1581
1582 #[test]
1583 fn test_split_binary_owned_two() {
1584 assert_eq!(
1585 split_binary_owned(col("a").eq(lit(5)).and(col("b")), Operator::And),
1586 vec![col("a").eq(lit(5)), col("b")]
1587 );
1588 }
1589
1590 #[test]
1591 fn test_split_binary_owned_different_op() {
1592 let expr = col("a").eq(lit(5)).or(col("b"));
1593 assert_eq!(
1594 split_binary_owned(expr.clone(), Operator::And),
1596 vec![expr]
1597 );
1598 }
1599
1600 #[test]
1601 fn test_split_conjunction_owned() {
1602 let expr = col("a");
1603 assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1604 }
1605
1606 #[test]
1607 fn test_split_conjunction_owned_two() {
1608 assert_eq!(
1609 split_conjunction_owned(col("a").eq(lit(5)).and(col("b"))),
1610 vec![col("a").eq(lit(5)), col("b")]
1611 );
1612 }
1613
1614 #[test]
1615 fn test_split_conjunction_owned_alias() {
1616 assert_eq!(
1617 split_conjunction_owned(col("a").eq(lit(5)).and(col("b").alias("the_alias"))),
1618 vec![
1619 col("a").eq(lit(5)),
1620 col("b"),
1622 ]
1623 );
1624 }
1625
1626 #[test]
1627 fn test_conjunction_empty() {
1628 assert_eq!(conjunction(vec![]), None);
1629 }
1630
1631 #[test]
1632 fn test_conjunction() {
1633 let expr = conjunction(vec![col("a"), col("b"), col("c")]);
1635
1636 assert_eq!(expr, Some(col("a").and(col("b")).and(col("c"))));
1638
1639 assert_ne!(expr, Some(col("a").and(col("b").and(col("c")))));
1641 }
1642
1643 #[test]
1644 fn test_disjunction_empty() {
1645 assert_eq!(disjunction(vec![]), None);
1646 }
1647
1648 #[test]
1649 fn test_disjunction() {
1650 let expr = disjunction(vec![col("a"), col("b"), col("c")]);
1652
1653 assert_eq!(expr, Some(col("a").or(col("b")).or(col("c"))));
1655
1656 assert_ne!(expr, Some(col("a").or(col("b").or(col("c")))));
1658 }
1659
1660 #[test]
1661 fn test_split_conjunction_owned_or() {
1662 let expr = col("a").eq(lit(5)).or(col("b"));
1663 assert_eq!(split_conjunction_owned(expr.clone()), vec![expr]);
1664 }
1665
1666 #[test]
1667 fn test_collect_expr() -> Result<()> {
1668 let mut accum: HashSet<Column> = HashSet::new();
1669 expr_to_columns(
1670 &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1671 &mut accum,
1672 )?;
1673 expr_to_columns(
1674 &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64)),
1675 &mut accum,
1676 )?;
1677 assert_eq!(1, accum.len());
1678 assert!(accum.contains(&Column::from_name("a")));
1679 Ok(())
1680 }
1681
1682 #[test]
1683 fn test_can_hash() {
1684 let union_fields: UnionFields = [
1685 (0, Arc::new(Field::new("A", DataType::Int32, true))),
1686 (1, Arc::new(Field::new("B", DataType::Float64, true))),
1687 ]
1688 .into_iter()
1689 .collect();
1690
1691 let union_type = DataType::Union(union_fields, UnionMode::Sparse);
1692 assert!(!can_hash(&union_type));
1693
1694 let list_union_type =
1695 DataType::List(Arc::new(Field::new("my_union", union_type, true)));
1696 assert!(!can_hash(&list_union_type));
1697 }
1698}