1use std::cmp::Ordering;
21use std::collections::{BTreeSet, HashSet};
22use std::sync::Arc;
23
24use crate::expr::{Alias, Sort, WildcardOptions, WindowFunctionParams};
25use crate::expr_rewriter::strip_outer_reference;
26use crate::{
27 and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
28};
29use datafusion_expr_common::signature::{Signature, TypeSignature};
30
31use arrow::datatypes::{DataType, Field, Schema};
32use datafusion_common::tree_node::{
33 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34};
35use datafusion_common::utils::get_at_indices;
36use datafusion_common::{
37 internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap,
38 Result, TableReference,
39};
40
41use indexmap::IndexSet;
42use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
43
44pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
45
46pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
49
50pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
53 if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
54 if group_expr.len() > 1 {
55 return plan_err!(
56 "Invalid group by expressions, GroupingSet must be the only expression"
57 );
58 }
59 Ok(grouping_set.distinct_expr().len() + 1)
61 } else {
62 grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
63 }
64}
65
/// Builds every subset of `slice`, returning references to the original items.
///
/// Subsets are produced in increasing bit-mask order: bit `i` of the mask
/// decides whether `slice[i]` is a member, so the empty set comes first and
/// the full set last. Inside each subset the items keep their slice order.
///
/// # Errors
/// Returns an error message when `slice` has 64 or more elements, because the
/// subsets are enumerated with a 64-bit mask.
fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
    if slice.len() >= 64 {
        return Err("The size of the set must be less than 64.".into());
    }

    let subset_count: u64 = 1 << slice.len();
    let subsets: Vec<Vec<&T>> = (0..subset_count)
        .map(|mask| {
            slice
                .iter()
                .enumerate()
                // Keep the item when its bit is set in the current mask.
                .filter(|&(idx, _)| mask & (1u64 << idx) != 0)
                .map(|(_, item)| item)
                .collect()
        })
        .collect();
    Ok(subsets)
}
104
105fn check_grouping_set_size_limit(size: usize) -> Result<()> {
107 let max_grouping_set_size = 65535;
108 if size > max_grouping_set_size {
109 return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}");
110 }
111
112 Ok(())
113}
114
115fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
117 let max_grouping_sets_size = 4096;
118 if size > max_grouping_sets_size {
119 return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}");
120 }
121
122 Ok(())
123}
124
125fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
137 check_grouping_set_size_limit(left.len() + right.len())?;
138 Ok(left.iter().chain(right.iter()).cloned().collect())
139}
140
141fn cross_join_grouping_sets<T: Clone>(
154 left: &[Vec<T>],
155 right: &[Vec<T>],
156) -> Result<Vec<Vec<T>>> {
157 let grouping_sets_size = left.len() * right.len();
158
159 check_grouping_sets_size_limit(grouping_sets_size)?;
160
161 let mut result = Vec::with_capacity(grouping_sets_size);
162 for le in left {
163 for re in right {
164 result.push(merge_grouping_set(le, re)?);
165 }
166 }
167 Ok(result)
168}
169
170pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
191 let has_grouping_set = group_expr
192 .iter()
193 .any(|expr| matches!(expr, Expr::GroupingSet(_)));
194 if !has_grouping_set || group_expr.len() == 1 {
195 return Ok(group_expr);
196 }
197 let partial_sets = group_expr
199 .iter()
200 .map(|expr| {
201 let exprs = match expr {
202 Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
203 check_grouping_sets_size_limit(grouping_sets.len())?;
204 grouping_sets.iter().map(|e| e.iter().collect()).collect()
205 }
206 Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
207 let grouping_sets = powerset(group_exprs)
208 .map_err(|e| plan_datafusion_err!("{}", e))?;
209 check_grouping_sets_size_limit(grouping_sets.len())?;
210 grouping_sets
211 }
212 Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
213 let size = group_exprs.len();
214 let slice = group_exprs.as_slice();
215 check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
216 (0..(size + 1))
217 .map(|i| slice[0..i].iter().collect())
218 .collect()
219 }
220 expr => vec![vec![expr]],
221 };
222 Ok(exprs)
223 })
224 .collect::<Result<Vec<_>>>()?;
225
226 let grouping_sets = partial_sets
228 .into_iter()
229 .map(Ok)
230 .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
231 .transpose()?
232 .map(|e| {
233 e.into_iter()
234 .map(|e| e.into_iter().cloned().collect())
235 .collect()
236 })
237 .unwrap_or_default();
238
239 Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
240 grouping_sets,
241 ))])
242}
243
244pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
247 if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
248 if group_expr.len() > 1 {
249 return plan_err!(
250 "Invalid group by expressions, GroupingSet must be the only expression"
251 );
252 }
253 Ok(grouping_set.distinct_expr())
254 } else {
255 Ok(group_expr
256 .iter()
257 .collect::<IndexSet<_>>()
258 .into_iter()
259 .collect())
260 }
261}
262
/// Walks `expr` and inserts every [`Expr::Column`] it contains into `accum`.
///
/// All other expression variants are visited but contribute nothing
/// themselves; their children are still reached because the traversal always
/// continues.
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
    expr.apply(|expr| {
        match expr {
            Expr::Column(qc) => {
                accum.insert(qc.clone());
            }
            // The variants are listed explicitly (no `_` arm) so that adding
            // a new `Expr` variant forces this function to be revisited.
            #[expect(deprecated)]
            Expr::Unnest(_)
            | Expr::ScalarVariable(_, _)
            | Expr::Alias(_)
            | Expr::Literal(_, _)
            | Expr::BinaryExpr { .. }
            | Expr::Like { .. }
            | Expr::SimilarTo { .. }
            | Expr::Not(_)
            | Expr::IsNotNull(_)
            | Expr::IsNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Negative(_)
            | Expr::Between { .. }
            | Expr::Case { .. }
            | Expr::Cast { .. }
            | Expr::TryCast { .. }
            | Expr::ScalarFunction(..)
            | Expr::WindowFunction { .. }
            | Expr::AggregateFunction { .. }
            | Expr::GroupingSet(_)
            | Expr::InList { .. }
            | Expr::Exists { .. }
            | Expr::InSubquery(_)
            | Expr::ScalarSubquery(_)
            | Expr::Wildcard { .. }
            | Expr::Placeholder(_)
            | Expr::OuterReferenceColumn { .. } => {}
        }
        Ok(TreeNodeRecursion::Continue)
    })
    // Discard the traversal state; callers only care about success.
    .map(|_| ())
}
313
314fn get_excluded_columns(
317 opt_exclude: Option<&ExcludeSelectItem>,
318 opt_except: Option<&ExceptSelectItem>,
319 schema: &DFSchema,
320 qualifier: Option<&TableReference>,
321) -> Result<Vec<Column>> {
322 let mut idents = vec![];
323 if let Some(excepts) = opt_except {
324 idents.push(&excepts.first_element);
325 idents.extend(&excepts.additional_elements);
326 }
327 if let Some(exclude) = opt_exclude {
328 match exclude {
329 ExcludeSelectItem::Single(ident) => idents.push(ident),
330 ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
331 }
332 }
333 let n_elem = idents.len();
335 let unique_idents = idents.into_iter().collect::<HashSet<_>>();
336 if n_elem != unique_idents.len() {
339 return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
340 }
341
342 let mut result = vec![];
343 for ident in unique_idents.into_iter() {
344 let col_name = ident.value.as_str();
345 let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
346 result.push(Column::from((qualifier, field)));
347 }
348 Ok(result)
349}
350
351fn get_exprs_except_skipped(
353 schema: &DFSchema,
354 columns_to_skip: HashSet<Column>,
355) -> Vec<Expr> {
356 if columns_to_skip.is_empty() {
357 schema.iter().map(Expr::from).collect::<Vec<Expr>>()
358 } else {
359 schema
360 .columns()
361 .iter()
362 .filter_map(|c| {
363 if !columns_to_skip.contains(c) {
364 Some(Expr::Column(c.clone()))
365 } else {
366 None
367 }
368 })
369 .collect::<Vec<Expr>>()
370 }
371}
372
373fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
377 let using_columns = plan.using_columns()?;
378 let excluded = using_columns
379 .into_iter()
380 .flat_map(|cols| {
382 let mut cols = cols.into_iter().collect::<Vec<_>>();
383 cols.sort();
386 let mut out_column_names: HashSet<String> = HashSet::new();
387 cols.into_iter().filter_map(move |c| {
388 if out_column_names.contains(&c.name) {
389 Some(c)
390 } else {
391 out_column_names.insert(c.name);
392 None
393 }
394 })
395 })
396 .collect::<HashSet<_>>();
397 Ok(excluded)
398}
399
400pub fn expand_wildcard(
402 schema: &DFSchema,
403 plan: &LogicalPlan,
404 wildcard_options: Option<&WildcardOptions>,
405) -> Result<Vec<Expr>> {
406 let mut columns_to_skip = exclude_using_columns(plan)?;
407 let excluded_columns = if let Some(WildcardOptions {
408 exclude: opt_exclude,
409 except: opt_except,
410 ..
411 }) = wildcard_options
412 {
413 get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
414 } else {
415 vec![]
416 };
417 columns_to_skip.extend(excluded_columns);
419 Ok(get_exprs_except_skipped(schema, columns_to_skip))
420}
421
422pub fn expand_qualified_wildcard(
424 qualifier: &TableReference,
425 schema: &DFSchema,
426 wildcard_options: Option<&WildcardOptions>,
427) -> Result<Vec<Expr>> {
428 let qualified_indices = schema.fields_indices_with_qualified(qualifier);
429 let projected_func_dependencies = schema
430 .functional_dependencies()
431 .project_functional_dependencies(&qualified_indices, qualified_indices.len());
432 let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
433 if fields_with_qualified.is_empty() {
434 return plan_err!("Invalid qualifier {qualifier}");
435 }
436
437 let qualified_schema = Arc::new(Schema::new_with_metadata(
438 fields_with_qualified,
439 schema.metadata().clone(),
440 ));
441 let qualified_dfschema =
442 DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
443 .with_functional_dependencies(projected_func_dependencies)?;
444 let excluded_columns = if let Some(WildcardOptions {
445 exclude: opt_exclude,
446 except: opt_except,
447 ..
448 }) = wildcard_options
449 {
450 get_excluded_columns(
451 opt_exclude.as_ref(),
452 opt_except.as_ref(),
453 schema,
454 Some(qualifier),
455 )?
456 } else {
457 vec![]
458 };
459 let mut columns_to_skip = HashSet::new();
461 columns_to_skip.extend(excluded_columns);
462 Ok(get_exprs_except_skipped(
463 &qualified_dfschema,
464 columns_to_skip,
465 ))
466}
467
/// Sort key of a window expression: each entry pairs a [`Sort`] with a flag
/// that is `true` when the entry was derived from a `PARTITION BY` expression
/// and `false` when it comes only from `ORDER BY` (see [`generate_sort_key`]).
type WindowSortKey = Vec<(Sort, bool)>;
471
472pub fn generate_sort_key(
474 partition_by: &[Expr],
475 order_by: &[Sort],
476) -> Result<WindowSortKey> {
477 let normalized_order_by_keys = order_by
478 .iter()
479 .map(|e| {
480 let Sort { expr, .. } = e;
481 Sort::new(expr.clone(), true, false)
482 })
483 .collect::<Vec<_>>();
484
485 let mut final_sort_keys = vec![];
486 let mut is_partition_flag = vec![];
487 partition_by.iter().for_each(|e| {
488 let e = e.clone().sort(true, false);
491 if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
492 let order_by_key = &order_by[pos];
493 if !final_sort_keys.contains(order_by_key) {
494 final_sort_keys.push(order_by_key.clone());
495 is_partition_flag.push(true);
496 }
497 } else if !final_sort_keys.contains(&e) {
498 final_sort_keys.push(e);
499 is_partition_flag.push(true);
500 }
501 });
502
503 order_by.iter().for_each(|e| {
504 if !final_sort_keys.contains(e) {
505 final_sort_keys.push(e.clone());
506 is_partition_flag.push(false);
507 }
508 });
509 let res = final_sort_keys
510 .into_iter()
511 .zip(is_partition_flag)
512 .collect::<Vec<_>>();
513 Ok(res)
514}
515
516pub fn compare_sort_expr(
519 sort_expr_a: &Sort,
520 sort_expr_b: &Sort,
521 schema: &DFSchemaRef,
522) -> Ordering {
523 let Sort {
524 expr: expr_a,
525 asc: asc_a,
526 nulls_first: nulls_first_a,
527 } = sort_expr_a;
528
529 let Sort {
530 expr: expr_b,
531 asc: asc_b,
532 nulls_first: nulls_first_b,
533 } = sort_expr_b;
534
535 let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
536 let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
537 for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
538 match idx_a.cmp(idx_b) {
539 Ordering::Less => {
540 return Ordering::Less;
541 }
542 Ordering::Greater => {
543 return Ordering::Greater;
544 }
545 Ordering::Equal => {}
546 }
547 }
548 match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
549 Ordering::Less => return Ordering::Greater,
550 Ordering::Greater => {
551 return Ordering::Less;
552 }
553 Ordering::Equal => {}
554 }
555 match (asc_a, asc_b) {
556 (true, false) => {
557 return Ordering::Greater;
558 }
559 (false, true) => {
560 return Ordering::Less;
561 }
562 _ => {}
563 }
564 match (nulls_first_a, nulls_first_b) {
565 (true, false) => {
566 return Ordering::Less;
567 }
568 (false, true) => {
569 return Ordering::Greater;
570 }
571 _ => {}
572 }
573 Ordering::Equal
574}
575
576pub fn group_window_expr_by_sort_keys(
578 window_expr: impl IntoIterator<Item = Expr>,
579) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
580 let mut result = vec![];
581 window_expr.into_iter().try_for_each(|expr| match &expr {
582 Expr::WindowFunction(window_fun) => {
583 let WindowFunctionParams{ partition_by, order_by, ..} = &window_fun.as_ref().params;
584 let sort_key = generate_sort_key(partition_by, order_by)?;
585 if let Some((_, values)) = result.iter_mut().find(
586 |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
587 ) {
588 values.push(expr);
589 } else {
590 result.push((sort_key, vec![expr]))
591 }
592 Ok(())
593 }
594 other => internal_err!(
595 "Impossibly got non-window expr {other:?}"
596 ),
597 })?;
598 Ok(result)
599}
600
601pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
605 find_exprs_in_exprs(exprs, &|nested_expr| {
606 matches!(nested_expr, Expr::AggregateFunction { .. })
607 })
608}
609
610pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
613 find_exprs_in_exprs(exprs, &|nested_expr| {
614 matches!(nested_expr, Expr::WindowFunction { .. })
615 })
616}
617
618pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
621 find_exprs_in_expr(expr, &|nested_expr| {
622 matches!(nested_expr, Expr::OuterReferenceColumn { .. })
623 })
624}
625
626fn find_exprs_in_exprs<'a, F>(
630 exprs: impl IntoIterator<Item = &'a Expr>,
631 test_fn: &F,
632) -> Vec<Expr>
633where
634 F: Fn(&Expr) -> bool,
635{
636 exprs
637 .into_iter()
638 .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
639 .fold(vec![], |mut acc, expr| {
640 if !acc.contains(&expr) {
641 acc.push(expr)
642 }
643 acc
644 })
645}
646
647fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
651where
652 F: Fn(&Expr) -> bool,
653{
654 let mut exprs = vec![];
655 expr.apply(|expr| {
656 if test_fn(expr) {
657 if !(exprs.contains(expr)) {
658 exprs.push(expr.clone())
659 }
660 return Ok(TreeNodeRecursion::Jump);
662 }
663
664 Ok(TreeNodeRecursion::Continue)
665 })
666 .expect("no way to return error during recursion");
668 exprs
669}
670
671pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
673where
674 F: FnMut(&Expr) -> Result<(), E>,
675{
676 let mut err = Ok(());
677 expr.apply(|expr| {
678 if let Err(e) = f(expr) {
679 err = Err(e);
681 Ok(TreeNodeRecursion::Stop)
682 } else {
683 Ok(TreeNodeRecursion::Continue)
685 }
686 })
687 .expect("no way to return error during recursion");
689
690 err
691}
692
693pub fn exprlist_to_fields<'a>(
695 exprs: impl IntoIterator<Item = &'a Expr>,
696 plan: &LogicalPlan,
697) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
698 let input_schema = plan.schema();
700 exprs
701 .into_iter()
702 .map(|e| e.to_field(input_schema))
703 .collect()
704}
705
706pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
722 let output_exprs = match input.columnized_output_exprs() {
723 Ok(exprs) if !exprs.is_empty() => exprs,
724 _ => return Ok(e),
725 };
726 let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
727 e.transform_down(|node: Expr| match exprs_map.get(&node) {
728 Some(column) => Ok(Transformed::new(
729 Expr::Column(column.clone()),
730 true,
731 TreeNodeRecursion::Jump,
732 )),
733 None => Ok(Transformed::no(node)),
734 })
735 .data()
736}
737
738pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
741 exprs
742 .iter()
743 .flat_map(find_columns_referenced_by_expr)
744 .map(Expr::Column)
745 .collect()
746}
747
748pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
749 let mut exprs = vec![];
750 e.apply(|expr| {
751 if let Expr::Column(c) = expr {
752 exprs.push(c.clone())
753 }
754 Ok(TreeNodeRecursion::Continue)
755 })
756 .expect("Unexpected error");
758 exprs
759}
760
761pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
763 match expr {
764 Expr::Column(col) => {
765 let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
766 Ok(Expr::from(Column::from((qualifier, field))))
767 }
768 _ => Ok(Expr::Column(Column::from_name(
769 expr.schema_name().to_string(),
770 ))),
771 }
772}
773
774pub(crate) fn find_column_indexes_referenced_by_expr(
777 e: &Expr,
778 schema: &DFSchemaRef,
779) -> Vec<usize> {
780 let mut indexes = vec![];
781 e.apply(|expr| {
782 match expr {
783 Expr::Column(qc) => {
784 if let Ok(idx) = schema.index_of_column(qc) {
785 indexes.push(idx);
786 }
787 }
788 Expr::Literal(_, _) => {
789 indexes.push(usize::MAX);
790 }
791 _ => {}
792 }
793 Ok(TreeNodeRecursion::Continue)
794 })
795 .unwrap();
796 indexes
797}
798
799pub fn can_hash(data_type: &DataType) -> bool {
803 match data_type {
804 DataType::Null => true,
805 DataType::Boolean => true,
806 DataType::Int8 => true,
807 DataType::Int16 => true,
808 DataType::Int32 => true,
809 DataType::Int64 => true,
810 DataType::UInt8 => true,
811 DataType::UInt16 => true,
812 DataType::UInt32 => true,
813 DataType::UInt64 => true,
814 DataType::Float16 => true,
815 DataType::Float32 => true,
816 DataType::Float64 => true,
817 DataType::Decimal128(_, _) => true,
818 DataType::Decimal256(_, _) => true,
819 DataType::Timestamp(_, _) => true,
820 DataType::Utf8 => true,
821 DataType::LargeUtf8 => true,
822 DataType::Utf8View => true,
823 DataType::Binary => true,
824 DataType::LargeBinary => true,
825 DataType::BinaryView => true,
826 DataType::Date32 => true,
827 DataType::Date64 => true,
828 DataType::Time32(_) => true,
829 DataType::Time64(_) => true,
830 DataType::Duration(_) => true,
831 DataType::Interval(_) => true,
832 DataType::FixedSizeBinary(_) => true,
833 DataType::Dictionary(key_type, value_type) => {
834 DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
835 }
836 DataType::List(value_type) => can_hash(value_type.data_type()),
837 DataType::LargeList(value_type) => can_hash(value_type.data_type()),
838 DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
839 DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
840 DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
841
842 DataType::ListView(_)
843 | DataType::LargeListView(_)
844 | DataType::Union(_, _)
845 | DataType::RunEndEncoded(_, _) => false,
846 }
847}
848
849pub fn check_all_columns_from_schema(
851 columns: &HashSet<&Column>,
852 schema: &DFSchema,
853) -> Result<bool> {
854 for col in columns.iter() {
855 let exist = schema.is_column_from_schema(col);
856 if !exist {
857 return Ok(false);
858 }
859 }
860
861 Ok(true)
862}
863
864pub fn find_valid_equijoin_key_pair(
874 left_key: &Expr,
875 right_key: &Expr,
876 left_schema: &DFSchema,
877 right_schema: &DFSchema,
878) -> Result<Option<(Expr, Expr)>> {
879 let left_using_columns = left_key.column_refs();
880 let right_using_columns = right_key.column_refs();
881
882 if left_using_columns.is_empty() || right_using_columns.is_empty() {
884 return Ok(None);
885 }
886
887 if check_all_columns_from_schema(&left_using_columns, left_schema)?
888 && check_all_columns_from_schema(&right_using_columns, right_schema)?
889 {
890 return Ok(Some((left_key.clone(), right_key.clone())));
891 } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
892 && check_all_columns_from_schema(&left_using_columns, right_schema)?
893 {
894 return Ok(Some((right_key.clone(), left_key.clone())));
895 }
896
897 Ok(None)
898}
899
900pub fn generate_signature_error_msg(
912 func_name: &str,
913 func_signature: Signature,
914 input_expr_types: &[DataType],
915) -> String {
916 let candidate_signatures = func_signature
917 .type_signature
918 .to_string_repr()
919 .iter()
920 .map(|args_str| format!("\t{func_name}({args_str})"))
921 .collect::<Vec<String>>()
922 .join("\n");
923
924 format!(
925 "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
926 func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures
927 )
928}
929
930pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
934 split_conjunction_impl(expr, vec![])
935}
936
937fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
938 match expr {
939 Expr::BinaryExpr(BinaryExpr {
940 right,
941 op: Operator::And,
942 left,
943 }) => {
944 let exprs = split_conjunction_impl(left, exprs);
945 split_conjunction_impl(right, exprs)
946 }
947 Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
948 other => {
949 exprs.push(other);
950 exprs
951 }
952 }
953}
954
955pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
959 let mut stack = vec![expr];
960 std::iter::from_fn(move || {
961 while let Some(expr) = stack.pop() {
962 match expr {
963 Expr::BinaryExpr(BinaryExpr {
964 right,
965 op: Operator::And,
966 left,
967 }) => {
968 stack.push(right);
969 stack.push(left);
970 }
971 Expr::Alias(Alias { expr, .. }) => stack.push(expr),
972 other => return Some(other),
973 }
974 }
975 None
976 })
977}
978
979pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
983 let mut stack = vec![expr];
984 std::iter::from_fn(move || {
985 while let Some(expr) = stack.pop() {
986 match expr {
987 Expr::BinaryExpr(BinaryExpr {
988 right,
989 op: Operator::And,
990 left,
991 }) => {
992 stack.push(*right);
993 stack.push(*left);
994 }
995 Expr::Alias(Alias { expr, .. }) => stack.push(*expr),
996 other => return Some(other),
997 }
998 }
999 None
1000 })
1001}
1002
1003pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1025 split_binary_owned(expr, Operator::And)
1026}
1027
1028pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1051 split_binary_owned_impl(expr, op, vec![])
1052}
1053
1054fn split_binary_owned_impl(
1055 expr: Expr,
1056 operator: Operator,
1057 mut exprs: Vec<Expr>,
1058) -> Vec<Expr> {
1059 match expr {
1060 Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1061 let exprs = split_binary_owned_impl(*left, operator, exprs);
1062 split_binary_owned_impl(*right, operator, exprs)
1063 }
1064 Expr::Alias(Alias { expr, .. }) => {
1065 split_binary_owned_impl(*expr, operator, exprs)
1066 }
1067 other => {
1068 exprs.push(other);
1069 exprs
1070 }
1071 }
1072}
1073
1074pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1078 split_binary_impl(expr, op, vec![])
1079}
1080
1081fn split_binary_impl<'a>(
1082 expr: &'a Expr,
1083 operator: Operator,
1084 mut exprs: Vec<&'a Expr>,
1085) -> Vec<&'a Expr> {
1086 match expr {
1087 Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1088 let exprs = split_binary_impl(left, operator, exprs);
1089 split_binary_impl(right, operator, exprs)
1090 }
1091 Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1092 other => {
1093 exprs.push(other);
1094 exprs
1095 }
1096 }
1097}
1098
1099pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1122 filters.into_iter().reduce(Expr::and)
1123}
1124
1125pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1148 filters.into_iter().reduce(Expr::or)
1149}
1150
1151pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1166 let predicate = predicates
1168 .iter()
1169 .skip(1)
1170 .fold(predicates[0].clone(), |acc, predicate| {
1171 and(acc, (*predicate).to_owned())
1172 });
1173
1174 Ok(LogicalPlan::Filter(Filter::try_new(
1175 predicate,
1176 Arc::new(plan),
1177 )?))
1178}
1179
1180pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1191 let mut joins = vec![];
1192 let mut others = vec![];
1193 for filter in exprs.into_iter() {
1194 if filter.contains_outer() {
1196 if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1197 {
1198 joins.push(strip_outer_reference((*filter).clone()));
1199 }
1200 } else {
1201 others.push((*filter).clone());
1202 }
1203 }
1204
1205 Ok((joins, others))
1206}
1207
1208pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1218 match slice {
1219 [it] => Ok(it),
1220 [] => plan_err!("No items found!"),
1221 _ => plan_err!("More than one item found!"),
1222 }
1223}
1224
1225pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1227 if inputs.len() == 1 {
1228 inputs[0].schema().as_ref().clone()
1229 } else {
1230 inputs.iter().map(|input| input.schema()).fold(
1231 DFSchema::empty(),
1232 |mut lhs, rhs| {
1233 lhs.merge(rhs);
1234 lhs
1235 },
1236 )
1237 }
1238}
1239
/// Builds a state field name of the form `name[state_name]`.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    [name, "[", state_name, "]"].concat()
}
1244
1245pub fn collect_subquery_cols(
1247 exprs: &[Expr],
1248 subquery_schema: &DFSchema,
1249) -> Result<BTreeSet<Column>> {
1250 exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
1251 let mut using_cols: Vec<Column> = vec![];
1252 for col in expr.column_refs().into_iter() {
1253 if subquery_schema.has_column(col) {
1254 using_cols.push(col.clone());
1255 }
1256 }
1257
1258 cols.extend(using_cols);
1259 Result::<_>::Ok(cols)
1260 })
1261}
1262
1263#[cfg(test)]
1264mod tests {
1265 use super::*;
1266 use crate::{
1267 col, cube,
1268 expr::WindowFunction,
1269 expr_vec_fmt, grouping_set, lit, rollup,
1270 test::function_stub::{max_udaf, min_udaf, sum_udaf},
1271 Cast, ExprFunctionExt, WindowFunctionDefinition,
1272 };
1273 use arrow::datatypes::{UnionFields, UnionMode};
1274
// Grouping an empty list of window expressions yields no groups.
#[test]
fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
    let result = group_window_expr_by_sort_keys(vec![])?;
    let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
    assert_eq!(expected, result);
    Ok(())
}
1282
// Window functions with neither PARTITION BY nor ORDER BY all share the
// empty sort key, so they end up in a single group in input order.
#[test]
fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
    let max1 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(max_udaf()),
        vec![col("name")],
    ));
    let max2 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(max_udaf()),
        vec![col("name")],
    ));
    let min3 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(min_udaf()),
        vec![col("name")],
    ));
    let sum4 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(sum_udaf()),
        vec![col("age")],
    ));
    let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
    let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
    // No partitioning or ordering: the shared key is empty.
    let key = vec![];
    let expected: Vec<(WindowSortKey, Vec<Expr>)> =
        vec![(key, vec![max1, max2, min3, sum4])];
    assert_eq!(expected, result);
    Ok(())
}
1309
// Window functions are grouped by their ORDER BY lists: identical lists
// share a group, and groups appear in first-seen order.
#[test]
fn test_group_window_expr_by_sort_keys() -> Result<()> {
    let age_asc = Sort::new(col("age"), true, true);
    let name_desc = Sort::new(col("name"), false, true);
    let created_at_desc = Sort::new(col("created_at"), false, true);
    let max1 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(max_udaf()),
        vec![col("name")],
    ))
    .order_by(vec![age_asc.clone(), name_desc.clone()])
    .build()
    .unwrap();
    // No ORDER BY: falls into the empty-key group.
    let max2 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(max_udaf()),
        vec![col("name")],
    ));
    // Same ORDER BY as `max1`, so it joins that group.
    let min3 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(min_udaf()),
        vec![col("name")],
    ))
    .order_by(vec![age_asc.clone(), name_desc.clone()])
    .build()
    .unwrap();
    let sum4 = Expr::from(WindowFunction::new(
        WindowFunctionDefinition::AggregateUDF(sum_udaf()),
        vec![col("age")],
    ))
    .order_by(vec![
        name_desc.clone(),
        age_asc.clone(),
        created_at_desc.clone(),
    ])
    .build()
    .unwrap();
    let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
    let result = group_window_expr_by_sort_keys(exprs.to_vec())?;

    // All flags are `false` because no entry comes from PARTITION BY.
    let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
    let key2 = vec![];
    let key3 = vec![
        (name_desc, false),
        (age_asc, false),
        (created_at_desc, false),
    ];

    let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
        (key1, vec![max1, min3]),
        (key2, vec![max2]),
        (key3, vec![sum4]),
    ];
    assert_eq!(expected, result);
    Ok(())
}
1364
// A column that appears in both PARTITION BY and ORDER BY must show up
// only once in the sort key, keeping the ORDER BY direction/null placement,
// for every combination of asc/desc and nulls first/last.
#[test]
fn avoid_generate_duplicate_sort_keys() -> Result<()> {
    let asc_or_desc = [true, false];
    let nulls_first_or_last = [true, false];
    let partition_by = &[col("age"), col("name"), col("created_at")];
    for asc_ in asc_or_desc {
        for nulls_first_ in nulls_first_or_last {
            let order_by = &[
                Sort {
                    expr: col("age"),
                    asc: asc_,
                    nulls_first: nulls_first_,
                },
                Sort {
                    expr: col("name"),
                    asc: asc_,
                    nulls_first: nulls_first_,
                },
            ];

            let expected = vec![
                (
                    Sort {
                        expr: col("age"),
                        asc: asc_,
                        nulls_first: nulls_first_,
                    },
                    true,
                ),
                (
                    Sort {
                        expr: col("name"),
                        asc: asc_,
                        nulls_first: nulls_first_,
                    },
                    true,
                ),
                // `created_at` is only in PARTITION BY, so it gets the
                // default ascending / nulls-last ordering.
                (
                    Sort {
                        expr: col("created_at"),
                        asc: true,
                        nulls_first: false,
                    },
                    true,
                ),
            ];
            let result = generate_sort_key(partition_by, order_by)?;
            assert_eq!(expected, result);
        }
    }
    Ok(())
}
1417
#[test]
fn test_enumerate_grouping_sets() -> Result<()> {
    // Table-driven check: each case pairs a GROUP BY expression list with the
    // text expected once `enumerate_grouping_sets` has processed it.
    let simple = col("simple_col");
    let columns = vec![col("col1"), col("col2"), col("col3")];
    let cube_expr = cube(columns.clone());
    let rollup_expr = rollup(columns.clone());
    let grouping_set_expr = grouping_set(vec![columns]);

    let cases = vec![
        // A single plain column, CUBE or ROLLUP is rendered unchanged.
        (vec![simple.clone()], "[simple_col]"),
        (vec![cube_expr.clone()], "[CUBE (col1, col2, col3)]"),
        (vec![rollup_expr.clone()], "[ROLLUP (col1, col2, col3)]"),
        // Column + CUBE: all 8 subsets of the cube columns, each prefixed
        // with the plain column.
        (
            vec![simple.clone(), cube_expr.clone()],
            "[GROUPING SETS (\
                (simple_col), \
                (simple_col, col1), \
                (simple_col, col2), \
                (simple_col, col1, col2), \
                (simple_col, col3), \
                (simple_col, col1, col3), \
                (simple_col, col2, col3), \
                (simple_col, col1, col2, col3))]",
        ),
        // Column + ROLLUP: the 4 rollup prefixes, each prefixed with the
        // plain column.
        (
            vec![simple.clone(), rollup_expr.clone()],
            "[GROUPING SETS (\
                (simple_col), \
                (simple_col, col1), \
                (simple_col, col1, col2), \
                (simple_col, col1, col2, col3))]",
        ),
        // Column + explicit GROUPING SETS: the single set gains the column.
        (
            vec![simple.clone(), grouping_set_expr.clone()],
            "[GROUPING SETS (\
                (simple_col, col1, col2, col3))]",
        ),
        // Column + GROUPING SETS + ROLLUP: the expected output keeps the
        // repeated columns (no de-duplication).
        (
            vec![simple.clone(), grouping_set_expr, rollup_expr.clone()],
            "[GROUPING SETS (\
                (simple_col, col1, col2, col3), \
                (simple_col, col1, col2, col3, col1), \
                (simple_col, col1, col2, col3, col1, col2), \
                (simple_col, col1, col2, col3, col1, col2, col3))]",
        ),
        // Column + CUBE + ROLLUP: 8 x 4 = 32 combinations.
        (
            vec![simple, cube_expr, rollup_expr],
            "[GROUPING SETS (\
                (simple_col), \
                (simple_col, col1), \
                (simple_col, col1, col2), \
                (simple_col, col1, col2, col3), \
                (simple_col, col1), \
                (simple_col, col1, col1), \
                (simple_col, col1, col1, col2), \
                (simple_col, col1, col1, col2, col3), \
                (simple_col, col2), \
                (simple_col, col2, col1), \
                (simple_col, col2, col1, col2), \
                (simple_col, col2, col1, col2, col3), \
                (simple_col, col1, col2), \
                (simple_col, col1, col2, col1), \
                (simple_col, col1, col2, col1, col2), \
                (simple_col, col1, col2, col1, col2, col3), \
                (simple_col, col3), \
                (simple_col, col3, col1), \
                (simple_col, col3, col1, col2), \
                (simple_col, col3, col1, col2, col3), \
                (simple_col, col1, col3), \
                (simple_col, col1, col3, col1), \
                (simple_col, col1, col3, col1, col2), \
                (simple_col, col1, col3, col1, col2, col3), \
                (simple_col, col2, col3), \
                (simple_col, col2, col3, col1), \
                (simple_col, col2, col3, col1, col2), \
                (simple_col, col2, col3, col1, col2, col3), \
                (simple_col, col1, col2, col3), \
                (simple_col, col1, col2, col3, col1), \
                (simple_col, col1, col2, col3, col1, col2), \
                (simple_col, col1, col2, col3, col1, col2, col3))]",
        ),
    ];

    // Cases run in declaration order, so the first mismatch reported is the
    // earliest failing case.
    for (group_exprs, expected) in cases {
        let sets = enumerate_grouping_sets(group_exprs)?;
        let rendered = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(expected, &rendered);
    }

    Ok(())
}
#[test]
fn test_split_conjunction() {
    // A bare column must come back as the single element of the result.
    let column = col("a");
    assert_eq!(split_conjunction(&column), vec![&column]);
}
1543
#[test]
fn test_split_conjunction_two() {
    // Build the two operands first, then AND them together, so the expected
    // pieces are exactly the values the conjunction was assembled from.
    let lhs = col("a").eq(lit(5));
    let rhs = col("b");
    let both = lhs.clone().and(rhs.clone());

    let pieces = split_conjunction(&both);
    assert_eq!(pieces, vec![&lhs, &rhs]);
}
1553
#[test]
fn test_split_conjunction_alias() {
    let lhs = col("a").eq(lit(5));
    // The expected right-hand piece carries no alias ...
    let rhs = col("b");
    // ... even though the input aliases it as `the_alias`.
    let aliased = lhs.clone().and(rhs.clone().alias("the_alias"));

    let pieces = split_conjunction(&aliased);
    assert_eq!(pieces, vec![&lhs, &rhs]);
}
1563
#[test]
fn test_split_conjunction_or() {
    // An OR at the top level must be returned whole, not split.
    let either = col("a").eq(lit(5)).or(col("b"));
    assert_eq!(split_conjunction(&either), vec![&either]);
}
1570
#[test]
fn test_split_binary_owned() {
    // Splitting a bare column on AND yields that column alone.
    let column = col("a");
    let pieces = split_binary_owned(column.clone(), Operator::And);
    assert_eq!(pieces, vec![column]);
}
1576
#[test]
fn test_split_binary_owned_two() {
    // `a = 5 AND b` split on AND gives back both operands, in order.
    let lhs = col("a").eq(lit(5));
    let rhs = col("b");

    let pieces = split_binary_owned(lhs.clone().and(rhs.clone()), Operator::And);
    assert_eq!(pieces, vec![lhs, rhs]);
}
1584
#[test]
fn test_split_binary_owned_different_op() {
    // The expression is an OR, so splitting on AND must leave it intact.
    let either = col("a").eq(lit(5)).or(col("b"));
    let pieces = split_binary_owned(either.clone(), Operator::And);
    assert_eq!(pieces, vec![either]);
}
1594
#[test]
fn test_split_conjunction_owned() {
    // A bare column is returned as the only element.
    let column = col("a");
    let pieces = split_conjunction_owned(column.clone());
    assert_eq!(pieces, vec![column]);
}
1600
#[test]
fn test_split_conjunction_owned_two() {
    // `a = 5 AND b` is split into its two operands, in order.
    let lhs = col("a").eq(lit(5));
    let rhs = col("b");

    let pieces = split_conjunction_owned(lhs.clone().and(rhs.clone()));
    assert_eq!(pieces, vec![lhs, rhs]);
}
1608
#[test]
fn test_split_conjunction_owned_alias() {
    let lhs = col("a").eq(lit(5));
    let rhs = col("b");
    let aliased = lhs.clone().and(rhs.clone().alias("the_alias"));

    // The expected right-hand piece is the plain column, without the alias.
    let expected = vec![lhs, rhs];
    assert_eq!(split_conjunction_owned(aliased), expected);
}
1620
#[test]
fn test_conjunction_empty() {
    // With nothing to combine, there is no resulting expression.
    let combined = conjunction(vec![]);
    assert_eq!(None, combined);
}
1625
#[test]
fn test_conjunction() {
    let (a, b, c) = (col("a"), col("b"), col("c"));
    let combined = conjunction(vec![a.clone(), b.clone(), c.clone()]);

    // [a, b, c] must produce the left-nested form (a AND b) AND c ...
    let left_nested = a.clone().and(b.clone()).and(c.clone());
    assert_eq!(combined, Some(left_nested));

    // ... which is structurally different from a AND (b AND c).
    let right_nested = a.and(b.and(c));
    assert_ne!(combined, Some(right_nested));
}
1637
#[test]
fn test_disjunction_empty() {
    // With nothing to combine, there is no resulting expression.
    let combined = disjunction(vec![]);
    assert_eq!(None, combined);
}
1642
#[test]
fn test_disjunction() {
    let (a, b, c) = (col("a"), col("b"), col("c"));
    let combined = disjunction(vec![a.clone(), b.clone(), c.clone()]);

    // [a, b, c] must produce the left-nested form (a OR b) OR c ...
    let left_nested = a.clone().or(b.clone()).or(c.clone());
    assert_eq!(combined, Some(left_nested));

    // ... which is structurally different from a OR (b OR c).
    let right_nested = a.or(b.or(c));
    assert_ne!(combined, Some(right_nested));
}
1654
#[test]
fn test_split_conjunction_owned_or() {
    // A top-level OR must be returned whole, not split.
    let either = col("a").eq(lit(5)).or(col("b"));
    let pieces = split_conjunction_owned(either.clone());
    assert_eq!(pieces, vec![either]);
}
1660
#[test]
fn test_collect_expr() -> Result<()> {
    let mut columns = HashSet::<Column>::new();
    let cast_of_a = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));

    // Collect from the same expression twice: the set must still end up with
    // a single entry for column `a`.
    for _ in 0..2 {
        expr_to_columns(&cast_of_a, &mut columns)?;
    }

    assert_eq!(1, columns.len());
    assert!(columns.contains(&Column::from_name("a")));
    Ok(())
}
1676
#[test]
fn test_can_hash() {
    // A sparse union with an Int32 member and a Float64 member.
    let members = [
        (0, Arc::new(Field::new("A", DataType::Int32, true))),
        (1, Arc::new(Field::new("B", DataType::Float64, true))),
    ];
    let sparse_union =
        DataType::Union(members.into_iter().collect::<UnionFields>(), UnionMode::Sparse);

    // The union itself must be reported as not hashable ...
    assert!(!can_hash(&sparse_union));

    // ... and so must a list whose element type is that union.
    let list_of_union = DataType::List(Arc::new(Field::new("my_union", sparse_union, true)));
    assert!(!can_hash(&list_of_union));
}
1693}