1use std::cmp::Ordering;
21use std::collections::{BTreeSet, HashSet};
22use std::sync::Arc;
23
24use crate::expr::{Alias, Sort, WildcardOptions, WindowFunction, WindowFunctionParams};
25use crate::expr_rewriter::strip_outer_reference;
26use crate::{
27 and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
28};
29use datafusion_expr_common::signature::{Signature, TypeSignature};
30
31use arrow::datatypes::{DataType, Field, Schema};
32use datafusion_common::tree_node::{
33 Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34};
35use datafusion_common::utils::get_at_indices;
36use datafusion_common::{
37 internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, HashMap,
38 Result, TableReference,
39};
40
41use indexmap::IndexSet;
42use sqlparser::ast::{ExceptSelectItem, ExcludeSelectItem};
43
44pub use datafusion_functions_aggregate_common::order::AggregateOrderSensitivity;
45
46pub use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
49
50pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
53 if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
54 if group_expr.len() > 1 {
55 return plan_err!(
56 "Invalid group by expressions, GroupingSet must be the only expression"
57 );
58 }
59 Ok(grouping_set.distinct_expr().len() + 1)
61 } else {
62 grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
63 }
64}
65
/// Builds the powerset (every subset) of `slice`.
///
/// Subsets are emitted in ascending order of their bitmask: bit `i` of the
/// mask decides whether `slice[i]` is a member, so the first subset is empty
/// and the last one contains every element. Inside each subset the elements
/// keep their order from `slice`.
///
/// # Errors
/// Returns an error message when `slice` holds 64 or more elements, because
/// the subsets are enumerated with a `u64` bitmask.
fn powerset<T>(slice: &[T]) -> Result<Vec<Vec<&T>>, String> {
    if slice.len() >= 64 {
        return Err("The size of the set must be less than 64.".into());
    }

    let subset_count: u64 = 1 << slice.len();
    let subsets = (0..subset_count)
        .map(|mask| {
            slice
                .iter()
                .enumerate()
                .filter(|&(idx, _)| (mask & (1u64 << idx)) != 0)
                .map(|(_, item)| item)
                .collect()
        })
        .collect();
    Ok(subsets)
}
104
105fn check_grouping_set_size_limit(size: usize) -> Result<()> {
107 let max_grouping_set_size = 65535;
108 if size > max_grouping_set_size {
109 return plan_err!("The number of group_expression in grouping_set exceeds the maximum limit {max_grouping_set_size}, found {size}");
110 }
111
112 Ok(())
113}
114
115fn check_grouping_sets_size_limit(size: usize) -> Result<()> {
117 let max_grouping_sets_size = 4096;
118 if size > max_grouping_sets_size {
119 return plan_err!("The number of grouping_set in grouping_sets exceeds the maximum limit {max_grouping_sets_size}, found {size}");
120 }
121
122 Ok(())
123}
124
125fn merge_grouping_set<T: Clone>(left: &[T], right: &[T]) -> Result<Vec<T>> {
137 check_grouping_set_size_limit(left.len() + right.len())?;
138 Ok(left.iter().chain(right.iter()).cloned().collect())
139}
140
141fn cross_join_grouping_sets<T: Clone>(
154 left: &[Vec<T>],
155 right: &[Vec<T>],
156) -> Result<Vec<Vec<T>>> {
157 let grouping_sets_size = left.len() * right.len();
158
159 check_grouping_sets_size_limit(grouping_sets_size)?;
160
161 let mut result = Vec::with_capacity(grouping_sets_size);
162 for le in left {
163 for re in right {
164 result.push(merge_grouping_set(le, re)?);
165 }
166 }
167 Ok(result)
168}
169
170pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
191 let has_grouping_set = group_expr
192 .iter()
193 .any(|expr| matches!(expr, Expr::GroupingSet(_)));
194 if !has_grouping_set || group_expr.len() == 1 {
195 return Ok(group_expr);
196 }
197 let partial_sets = group_expr
199 .iter()
200 .map(|expr| {
201 let exprs = match expr {
202 Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets)) => {
203 check_grouping_sets_size_limit(grouping_sets.len())?;
204 grouping_sets.iter().map(|e| e.iter().collect()).collect()
205 }
206 Expr::GroupingSet(GroupingSet::Cube(group_exprs)) => {
207 let grouping_sets = powerset(group_exprs)
208 .map_err(|e| plan_datafusion_err!("{}", e))?;
209 check_grouping_sets_size_limit(grouping_sets.len())?;
210 grouping_sets
211 }
212 Expr::GroupingSet(GroupingSet::Rollup(group_exprs)) => {
213 let size = group_exprs.len();
214 let slice = group_exprs.as_slice();
215 check_grouping_sets_size_limit(size * (size + 1) / 2 + 1)?;
216 (0..(size + 1))
217 .map(|i| slice[0..i].iter().collect())
218 .collect()
219 }
220 expr => vec![vec![expr]],
221 };
222 Ok(exprs)
223 })
224 .collect::<Result<Vec<_>>>()?;
225
226 let grouping_sets = partial_sets
228 .into_iter()
229 .map(Ok)
230 .reduce(|l, r| cross_join_grouping_sets(&l?, &r?))
231 .transpose()?
232 .map(|e| {
233 e.into_iter()
234 .map(|e| e.into_iter().cloned().collect())
235 .collect()
236 })
237 .unwrap_or_default();
238
239 Ok(vec![Expr::GroupingSet(GroupingSet::GroupingSets(
240 grouping_sets,
241 ))])
242}
243
244pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
247 if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
248 if group_expr.len() > 1 {
249 return plan_err!(
250 "Invalid group by expressions, GroupingSet must be the only expression"
251 );
252 }
253 Ok(grouping_set.distinct_expr())
254 } else {
255 Ok(group_expr
256 .iter()
257 .collect::<IndexSet<_>>()
258 .into_iter()
259 .collect())
260 }
261}
262
/// Walks `expr` and inserts every [`Column`] found in an `Expr::Column` node
/// into `accum`.
///
/// The visitor closure itself never produces an error; an `Err` can only
/// originate from the traversal machinery.
pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
    expr.apply(|expr| {
        match expr {
            Expr::Column(qc) => {
                accum.insert(qc.clone());
            }
            // All remaining variants are spelled out instead of using a `_`
            // catch-all, so adding a new `Expr` variant fails to compile here
            // and forces a decision about how it should be handled.
            // NOTE(review): the `expect(deprecated)` presumably exists for
            // `Expr::Wildcard` — confirm which variant is deprecated.
            #[expect(deprecated)]
            Expr::Unnest(_)
            | Expr::ScalarVariable(_, _)
            | Expr::Alias(_)
            | Expr::Literal(_)
            | Expr::BinaryExpr { .. }
            | Expr::Like { .. }
            | Expr::SimilarTo { .. }
            | Expr::Not(_)
            | Expr::IsNotNull(_)
            | Expr::IsNull(_)
            | Expr::IsTrue(_)
            | Expr::IsFalse(_)
            | Expr::IsUnknown(_)
            | Expr::IsNotTrue(_)
            | Expr::IsNotFalse(_)
            | Expr::IsNotUnknown(_)
            | Expr::Negative(_)
            | Expr::Between { .. }
            | Expr::Case { .. }
            | Expr::Cast { .. }
            | Expr::TryCast { .. }
            | Expr::ScalarFunction(..)
            | Expr::WindowFunction { .. }
            | Expr::AggregateFunction { .. }
            | Expr::GroupingSet(_)
            | Expr::InList { .. }
            | Expr::Exists { .. }
            | Expr::InSubquery(_)
            | Expr::ScalarSubquery(_)
            | Expr::Wildcard { .. }
            | Expr::Placeholder(_)
            | Expr::OuterReferenceColumn { .. } => {}
        }
        // Keep descending so columns nested in sub-expressions are found too.
        Ok(TreeNodeRecursion::Continue)
    })
    .map(|_| ())
}
313
314fn get_excluded_columns(
317 opt_exclude: Option<&ExcludeSelectItem>,
318 opt_except: Option<&ExceptSelectItem>,
319 schema: &DFSchema,
320 qualifier: Option<&TableReference>,
321) -> Result<Vec<Column>> {
322 let mut idents = vec![];
323 if let Some(excepts) = opt_except {
324 idents.push(&excepts.first_element);
325 idents.extend(&excepts.additional_elements);
326 }
327 if let Some(exclude) = opt_exclude {
328 match exclude {
329 ExcludeSelectItem::Single(ident) => idents.push(ident),
330 ExcludeSelectItem::Multiple(idents_inner) => idents.extend(idents_inner),
331 }
332 }
333 let n_elem = idents.len();
335 let unique_idents = idents.into_iter().collect::<HashSet<_>>();
336 if n_elem != unique_idents.len() {
339 return plan_err!("EXCLUDE or EXCEPT contains duplicate column names");
340 }
341
342 let mut result = vec![];
343 for ident in unique_idents.into_iter() {
344 let col_name = ident.value.as_str();
345 let (qualifier, field) = schema.qualified_field_with_name(qualifier, col_name)?;
346 result.push(Column::from((qualifier, field)));
347 }
348 Ok(result)
349}
350
351fn get_exprs_except_skipped(
353 schema: &DFSchema,
354 columns_to_skip: HashSet<Column>,
355) -> Vec<Expr> {
356 if columns_to_skip.is_empty() {
357 schema.iter().map(Expr::from).collect::<Vec<Expr>>()
358 } else {
359 schema
360 .columns()
361 .iter()
362 .filter_map(|c| {
363 if !columns_to_skip.contains(c) {
364 Some(Expr::Column(c.clone()))
365 } else {
366 None
367 }
368 })
369 .collect::<Vec<Expr>>()
370 }
371}
372
373fn exclude_using_columns(plan: &LogicalPlan) -> Result<HashSet<Column>> {
377 let using_columns = plan.using_columns()?;
378 let excluded = using_columns
379 .into_iter()
380 .flat_map(|cols| {
382 let mut cols = cols.into_iter().collect::<Vec<_>>();
383 cols.sort();
386 let mut out_column_names: HashSet<String> = HashSet::new();
387 cols.into_iter().filter_map(move |c| {
388 if out_column_names.contains(&c.name) {
389 Some(c)
390 } else {
391 out_column_names.insert(c.name);
392 None
393 }
394 })
395 })
396 .collect::<HashSet<_>>();
397 Ok(excluded)
398}
399
400pub fn expand_wildcard(
402 schema: &DFSchema,
403 plan: &LogicalPlan,
404 wildcard_options: Option<&WildcardOptions>,
405) -> Result<Vec<Expr>> {
406 let mut columns_to_skip = exclude_using_columns(plan)?;
407 let excluded_columns = if let Some(WildcardOptions {
408 exclude: opt_exclude,
409 except: opt_except,
410 ..
411 }) = wildcard_options
412 {
413 get_excluded_columns(opt_exclude.as_ref(), opt_except.as_ref(), schema, None)?
414 } else {
415 vec![]
416 };
417 columns_to_skip.extend(excluded_columns);
419 Ok(get_exprs_except_skipped(schema, columns_to_skip))
420}
421
422pub fn expand_qualified_wildcard(
424 qualifier: &TableReference,
425 schema: &DFSchema,
426 wildcard_options: Option<&WildcardOptions>,
427) -> Result<Vec<Expr>> {
428 let qualified_indices = schema.fields_indices_with_qualified(qualifier);
429 let projected_func_dependencies = schema
430 .functional_dependencies()
431 .project_functional_dependencies(&qualified_indices, qualified_indices.len());
432 let fields_with_qualified = get_at_indices(schema.fields(), &qualified_indices)?;
433 if fields_with_qualified.is_empty() {
434 return plan_err!("Invalid qualifier {qualifier}");
435 }
436
437 let qualified_schema = Arc::new(Schema::new_with_metadata(
438 fields_with_qualified,
439 schema.metadata().clone(),
440 ));
441 let qualified_dfschema =
442 DFSchema::try_from_qualified_schema(qualifier.clone(), &qualified_schema)?
443 .with_functional_dependencies(projected_func_dependencies)?;
444 let excluded_columns = if let Some(WildcardOptions {
445 exclude: opt_exclude,
446 except: opt_except,
447 ..
448 }) = wildcard_options
449 {
450 get_excluded_columns(
451 opt_exclude.as_ref(),
452 opt_except.as_ref(),
453 schema,
454 Some(qualifier),
455 )?
456 } else {
457 vec![]
458 };
459 let mut columns_to_skip = HashSet::new();
461 columns_to_skip.extend(excluded_columns);
462 Ok(get_exprs_except_skipped(
463 &qualified_dfschema,
464 columns_to_skip,
465 ))
466}
467
/// Sort key of a window expression: each entry pairs a sort expression with
/// a flag that is `true` when the entry was derived from a `PARTITION BY`
/// expression and `false` when it comes only from `ORDER BY`.
type WindowSortKey = Vec<(Sort, bool)>;
471
472pub fn generate_sort_key(
474 partition_by: &[Expr],
475 order_by: &[Sort],
476) -> Result<WindowSortKey> {
477 let normalized_order_by_keys = order_by
478 .iter()
479 .map(|e| {
480 let Sort { expr, .. } = e;
481 Sort::new(expr.clone(), true, false)
482 })
483 .collect::<Vec<_>>();
484
485 let mut final_sort_keys = vec![];
486 let mut is_partition_flag = vec![];
487 partition_by.iter().for_each(|e| {
488 let e = e.clone().sort(true, false);
491 if let Some(pos) = normalized_order_by_keys.iter().position(|key| key.eq(&e)) {
492 let order_by_key = &order_by[pos];
493 if !final_sort_keys.contains(order_by_key) {
494 final_sort_keys.push(order_by_key.clone());
495 is_partition_flag.push(true);
496 }
497 } else if !final_sort_keys.contains(&e) {
498 final_sort_keys.push(e);
499 is_partition_flag.push(true);
500 }
501 });
502
503 order_by.iter().for_each(|e| {
504 if !final_sort_keys.contains(e) {
505 final_sort_keys.push(e.clone());
506 is_partition_flag.push(false);
507 }
508 });
509 let res = final_sort_keys
510 .into_iter()
511 .zip(is_partition_flag)
512 .collect::<Vec<_>>();
513 Ok(res)
514}
515
516pub fn compare_sort_expr(
519 sort_expr_a: &Sort,
520 sort_expr_b: &Sort,
521 schema: &DFSchemaRef,
522) -> Ordering {
523 let Sort {
524 expr: expr_a,
525 asc: asc_a,
526 nulls_first: nulls_first_a,
527 } = sort_expr_a;
528
529 let Sort {
530 expr: expr_b,
531 asc: asc_b,
532 nulls_first: nulls_first_b,
533 } = sort_expr_b;
534
535 let ref_indexes_a = find_column_indexes_referenced_by_expr(expr_a, schema);
536 let ref_indexes_b = find_column_indexes_referenced_by_expr(expr_b, schema);
537 for (idx_a, idx_b) in ref_indexes_a.iter().zip(ref_indexes_b.iter()) {
538 match idx_a.cmp(idx_b) {
539 Ordering::Less => {
540 return Ordering::Less;
541 }
542 Ordering::Greater => {
543 return Ordering::Greater;
544 }
545 Ordering::Equal => {}
546 }
547 }
548 match ref_indexes_a.len().cmp(&ref_indexes_b.len()) {
549 Ordering::Less => return Ordering::Greater,
550 Ordering::Greater => {
551 return Ordering::Less;
552 }
553 Ordering::Equal => {}
554 }
555 match (asc_a, asc_b) {
556 (true, false) => {
557 return Ordering::Greater;
558 }
559 (false, true) => {
560 return Ordering::Less;
561 }
562 _ => {}
563 }
564 match (nulls_first_a, nulls_first_b) {
565 (true, false) => {
566 return Ordering::Less;
567 }
568 (false, true) => {
569 return Ordering::Greater;
570 }
571 _ => {}
572 }
573 Ordering::Equal
574}
575
576pub fn group_window_expr_by_sort_keys(
578 window_expr: impl IntoIterator<Item = Expr>,
579) -> Result<Vec<(WindowSortKey, Vec<Expr>)>> {
580 let mut result = vec![];
581 window_expr.into_iter().try_for_each(|expr| match &expr {
582 Expr::WindowFunction( WindowFunction{ params: WindowFunctionParams { partition_by, order_by, ..}, .. }) => {
583 let sort_key = generate_sort_key(partition_by, order_by)?;
584 if let Some((_, values)) = result.iter_mut().find(
585 |group: &&mut (WindowSortKey, Vec<Expr>)| matches!(group, (key, _) if *key == sort_key),
586 ) {
587 values.push(expr);
588 } else {
589 result.push((sort_key, vec![expr]))
590 }
591 Ok(())
592 }
593 other => internal_err!(
594 "Impossibly got non-window expr {other:?}"
595 ),
596 })?;
597 Ok(result)
598}
599
600pub fn find_aggregate_exprs<'a>(exprs: impl IntoIterator<Item = &'a Expr>) -> Vec<Expr> {
604 find_exprs_in_exprs(exprs, &|nested_expr| {
605 matches!(nested_expr, Expr::AggregateFunction { .. })
606 })
607}
608
609pub fn find_window_exprs(exprs: &[Expr]) -> Vec<Expr> {
612 find_exprs_in_exprs(exprs, &|nested_expr| {
613 matches!(nested_expr, Expr::WindowFunction { .. })
614 })
615}
616
617pub fn find_out_reference_exprs(expr: &Expr) -> Vec<Expr> {
620 find_exprs_in_expr(expr, &|nested_expr| {
621 matches!(nested_expr, Expr::OuterReferenceColumn { .. })
622 })
623}
624
625fn find_exprs_in_exprs<'a, F>(
629 exprs: impl IntoIterator<Item = &'a Expr>,
630 test_fn: &F,
631) -> Vec<Expr>
632where
633 F: Fn(&Expr) -> bool,
634{
635 exprs
636 .into_iter()
637 .flat_map(|expr| find_exprs_in_expr(expr, test_fn))
638 .fold(vec![], |mut acc, expr| {
639 if !acc.contains(&expr) {
640 acc.push(expr)
641 }
642 acc
643 })
644}
645
646fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr>
650where
651 F: Fn(&Expr) -> bool,
652{
653 let mut exprs = vec![];
654 expr.apply(|expr| {
655 if test_fn(expr) {
656 if !(exprs.contains(expr)) {
657 exprs.push(expr.clone())
658 }
659 return Ok(TreeNodeRecursion::Jump);
661 }
662
663 Ok(TreeNodeRecursion::Continue)
664 })
665 .expect("no way to return error during recursion");
667 exprs
668}
669
670pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E>
672where
673 F: FnMut(&Expr) -> Result<(), E>,
674{
675 let mut err = Ok(());
676 expr.apply(|expr| {
677 if let Err(e) = f(expr) {
678 err = Err(e);
680 Ok(TreeNodeRecursion::Stop)
681 } else {
682 Ok(TreeNodeRecursion::Continue)
684 }
685 })
686 .expect("no way to return error during recursion");
688
689 err
690}
691
692pub fn exprlist_to_fields<'a>(
694 exprs: impl IntoIterator<Item = &'a Expr>,
695 plan: &LogicalPlan,
696) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
697 let input_schema = plan.schema();
699 exprs
700 .into_iter()
701 .map(|e| e.to_field(input_schema))
702 .collect()
703}
704
705pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
721 let output_exprs = match input.columnized_output_exprs() {
722 Ok(exprs) if !exprs.is_empty() => exprs,
723 _ => return Ok(e),
724 };
725 let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
726 e.transform_down(|node: Expr| match exprs_map.get(&node) {
727 Some(column) => Ok(Transformed::new(
728 Expr::Column(column.clone()),
729 true,
730 TreeNodeRecursion::Jump,
731 )),
732 None => Ok(Transformed::no(node)),
733 })
734 .data()
735}
736
737pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> {
740 exprs
741 .iter()
742 .flat_map(find_columns_referenced_by_expr)
743 .map(Expr::Column)
744 .collect()
745}
746
747pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> {
748 let mut exprs = vec![];
749 e.apply(|expr| {
750 if let Expr::Column(c) = expr {
751 exprs.push(c.clone())
752 }
753 Ok(TreeNodeRecursion::Continue)
754 })
755 .expect("Unexpected error");
757 exprs
758}
759
760pub fn expr_as_column_expr(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
762 match expr {
763 Expr::Column(col) => {
764 let (qualifier, field) = plan.schema().qualified_field_from_column(col)?;
765 Ok(Expr::from(Column::from((qualifier, field))))
766 }
767 _ => Ok(Expr::Column(Column::from_name(
768 expr.schema_name().to_string(),
769 ))),
770 }
771}
772
773pub(crate) fn find_column_indexes_referenced_by_expr(
776 e: &Expr,
777 schema: &DFSchemaRef,
778) -> Vec<usize> {
779 let mut indexes = vec![];
780 e.apply(|expr| {
781 match expr {
782 Expr::Column(qc) => {
783 if let Ok(idx) = schema.index_of_column(qc) {
784 indexes.push(idx);
785 }
786 }
787 Expr::Literal(_) => {
788 indexes.push(usize::MAX);
789 }
790 _ => {}
791 }
792 Ok(TreeNodeRecursion::Continue)
793 })
794 .unwrap();
795 indexes
796}
797
798pub fn can_hash(data_type: &DataType) -> bool {
802 match data_type {
803 DataType::Null => true,
804 DataType::Boolean => true,
805 DataType::Int8 => true,
806 DataType::Int16 => true,
807 DataType::Int32 => true,
808 DataType::Int64 => true,
809 DataType::UInt8 => true,
810 DataType::UInt16 => true,
811 DataType::UInt32 => true,
812 DataType::UInt64 => true,
813 DataType::Float16 => true,
814 DataType::Float32 => true,
815 DataType::Float64 => true,
816 DataType::Decimal128(_, _) => true,
817 DataType::Decimal256(_, _) => true,
818 DataType::Timestamp(_, _) => true,
819 DataType::Utf8 => true,
820 DataType::LargeUtf8 => true,
821 DataType::Utf8View => true,
822 DataType::Binary => true,
823 DataType::LargeBinary => true,
824 DataType::BinaryView => true,
825 DataType::Date32 => true,
826 DataType::Date64 => true,
827 DataType::Time32(_) => true,
828 DataType::Time64(_) => true,
829 DataType::Duration(_) => true,
830 DataType::Interval(_) => true,
831 DataType::FixedSizeBinary(_) => true,
832 DataType::Dictionary(key_type, value_type) => {
833 DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
834 }
835 DataType::List(value_type) => can_hash(value_type.data_type()),
836 DataType::LargeList(value_type) => can_hash(value_type.data_type()),
837 DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
838 DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
839 DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
840
841 DataType::ListView(_)
842 | DataType::LargeListView(_)
843 | DataType::Union(_, _)
844 | DataType::RunEndEncoded(_, _) => false,
845 }
846}
847
848pub fn check_all_columns_from_schema(
850 columns: &HashSet<&Column>,
851 schema: &DFSchema,
852) -> Result<bool> {
853 for col in columns.iter() {
854 let exist = schema.is_column_from_schema(col);
855 if !exist {
856 return Ok(false);
857 }
858 }
859
860 Ok(true)
861}
862
863pub fn find_valid_equijoin_key_pair(
873 left_key: &Expr,
874 right_key: &Expr,
875 left_schema: &DFSchema,
876 right_schema: &DFSchema,
877) -> Result<Option<(Expr, Expr)>> {
878 let left_using_columns = left_key.column_refs();
879 let right_using_columns = right_key.column_refs();
880
881 if left_using_columns.is_empty() || right_using_columns.is_empty() {
883 return Ok(None);
884 }
885
886 if check_all_columns_from_schema(&left_using_columns, left_schema)?
887 && check_all_columns_from_schema(&right_using_columns, right_schema)?
888 {
889 return Ok(Some((left_key.clone(), right_key.clone())));
890 } else if check_all_columns_from_schema(&right_using_columns, left_schema)?
891 && check_all_columns_from_schema(&left_using_columns, right_schema)?
892 {
893 return Ok(Some((right_key.clone(), left_key.clone())));
894 }
895
896 Ok(None)
897}
898
899pub fn generate_signature_error_msg(
911 func_name: &str,
912 func_signature: Signature,
913 input_expr_types: &[DataType],
914) -> String {
915 let candidate_signatures = func_signature
916 .type_signature
917 .to_string_repr()
918 .iter()
919 .map(|args_str| format!("\t{func_name}({args_str})"))
920 .collect::<Vec<String>>()
921 .join("\n");
922
923 format!(
924 "No function matches the given name and argument types '{}({})'. You might need to add explicit type casts.\n\tCandidate functions:\n{}",
925 func_name, TypeSignature::join_types(input_expr_types, ", "), candidate_signatures
926 )
927}
928
929pub fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
933 split_conjunction_impl(expr, vec![])
934}
935
936fn split_conjunction_impl<'a>(expr: &'a Expr, mut exprs: Vec<&'a Expr>) -> Vec<&'a Expr> {
937 match expr {
938 Expr::BinaryExpr(BinaryExpr {
939 right,
940 op: Operator::And,
941 left,
942 }) => {
943 let exprs = split_conjunction_impl(left, exprs);
944 split_conjunction_impl(right, exprs)
945 }
946 Expr::Alias(Alias { expr, .. }) => split_conjunction_impl(expr, exprs),
947 other => {
948 exprs.push(other);
949 exprs
950 }
951 }
952}
953
954pub fn iter_conjunction(expr: &Expr) -> impl Iterator<Item = &Expr> {
958 let mut stack = vec![expr];
959 std::iter::from_fn(move || {
960 while let Some(expr) = stack.pop() {
961 match expr {
962 Expr::BinaryExpr(BinaryExpr {
963 right,
964 op: Operator::And,
965 left,
966 }) => {
967 stack.push(right);
968 stack.push(left);
969 }
970 Expr::Alias(Alias { expr, .. }) => stack.push(expr),
971 other => return Some(other),
972 }
973 }
974 None
975 })
976}
977
978pub fn iter_conjunction_owned(expr: Expr) -> impl Iterator<Item = Expr> {
982 let mut stack = vec![expr];
983 std::iter::from_fn(move || {
984 while let Some(expr) = stack.pop() {
985 match expr {
986 Expr::BinaryExpr(BinaryExpr {
987 right,
988 op: Operator::And,
989 left,
990 }) => {
991 stack.push(*right);
992 stack.push(*left);
993 }
994 Expr::Alias(Alias { expr, .. }) => stack.push(*expr),
995 other => return Some(other),
996 }
997 }
998 None
999 })
1000}
1001
1002pub fn split_conjunction_owned(expr: Expr) -> Vec<Expr> {
1024 split_binary_owned(expr, Operator::And)
1025}
1026
1027pub fn split_binary_owned(expr: Expr, op: Operator) -> Vec<Expr> {
1050 split_binary_owned_impl(expr, op, vec![])
1051}
1052
1053fn split_binary_owned_impl(
1054 expr: Expr,
1055 operator: Operator,
1056 mut exprs: Vec<Expr>,
1057) -> Vec<Expr> {
1058 match expr {
1059 Expr::BinaryExpr(BinaryExpr { right, op, left }) if op == operator => {
1060 let exprs = split_binary_owned_impl(*left, operator, exprs);
1061 split_binary_owned_impl(*right, operator, exprs)
1062 }
1063 Expr::Alias(Alias { expr, .. }) => {
1064 split_binary_owned_impl(*expr, operator, exprs)
1065 }
1066 other => {
1067 exprs.push(other);
1068 exprs
1069 }
1070 }
1071}
1072
1073pub fn split_binary(expr: &Expr, op: Operator) -> Vec<&Expr> {
1077 split_binary_impl(expr, op, vec![])
1078}
1079
1080fn split_binary_impl<'a>(
1081 expr: &'a Expr,
1082 operator: Operator,
1083 mut exprs: Vec<&'a Expr>,
1084) -> Vec<&'a Expr> {
1085 match expr {
1086 Expr::BinaryExpr(BinaryExpr { right, op, left }) if *op == operator => {
1087 let exprs = split_binary_impl(left, operator, exprs);
1088 split_binary_impl(right, operator, exprs)
1089 }
1090 Expr::Alias(Alias { expr, .. }) => split_binary_impl(expr, operator, exprs),
1091 other => {
1092 exprs.push(other);
1093 exprs
1094 }
1095 }
1096}
1097
1098pub fn conjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1121 filters.into_iter().reduce(Expr::and)
1122}
1123
1124pub fn disjunction(filters: impl IntoIterator<Item = Expr>) -> Option<Expr> {
1147 filters.into_iter().reduce(Expr::or)
1148}
1149
1150pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result<LogicalPlan> {
1165 let predicate = predicates
1167 .iter()
1168 .skip(1)
1169 .fold(predicates[0].clone(), |acc, predicate| {
1170 and(acc, (*predicate).to_owned())
1171 });
1172
1173 Ok(LogicalPlan::Filter(Filter::try_new(
1174 predicate,
1175 Arc::new(plan),
1176 )?))
1177}
1178
1179pub fn find_join_exprs(exprs: Vec<&Expr>) -> Result<(Vec<Expr>, Vec<Expr>)> {
1190 let mut joins = vec![];
1191 let mut others = vec![];
1192 for filter in exprs.into_iter() {
1193 if filter.contains_outer() {
1195 if !matches!(filter, Expr::BinaryExpr(BinaryExpr{ left, op: Operator::Eq, right }) if left.eq(right))
1196 {
1197 joins.push(strip_outer_reference((*filter).clone()));
1198 }
1199 } else {
1200 others.push((*filter).clone());
1201 }
1202 }
1203
1204 Ok((joins, others))
1205}
1206
1207pub fn only_or_err<T>(slice: &[T]) -> Result<&T> {
1217 match slice {
1218 [it] => Ok(it),
1219 [] => plan_err!("No items found!"),
1220 _ => plan_err!("More than one item found!"),
1221 }
1222}
1223
1224pub fn merge_schema(inputs: &[&LogicalPlan]) -> DFSchema {
1226 if inputs.len() == 1 {
1227 inputs[0].schema().as_ref().clone()
1228 } else {
1229 inputs.iter().map(|input| input.schema()).fold(
1230 DFSchema::empty(),
1231 |mut lhs, rhs| {
1232 lhs.merge(rhs);
1233 lhs
1234 },
1235 )
1236 }
1237}
1238
/// Builds a state field name of the form `name[state_name]`.
pub fn format_state_name(name: &str, state_name: &str) -> String {
    // Two extra bytes for the surrounding brackets.
    let mut formatted = String::with_capacity(name.len() + state_name.len() + 2);
    formatted.push_str(name);
    formatted.push('[');
    formatted.push_str(state_name);
    formatted.push(']');
    formatted
}
1243
1244pub fn collect_subquery_cols(
1246 exprs: &[Expr],
1247 subquery_schema: &DFSchema,
1248) -> Result<BTreeSet<Column>> {
1249 exprs.iter().try_fold(BTreeSet::new(), |mut cols, expr| {
1250 let mut using_cols: Vec<Column> = vec![];
1251 for col in expr.column_refs().into_iter() {
1252 if subquery_schema.has_column(col) {
1253 using_cols.push(col.clone());
1254 }
1255 }
1256
1257 cols.extend(using_cols);
1258 Result::<_>::Ok(cols)
1259 })
1260}
1261
1262#[cfg(test)]
1263mod tests {
1264 use super::*;
1265 use crate::{
1266 col, cube, expr_vec_fmt, grouping_set, lit, rollup,
1267 test::function_stub::max_udaf, test::function_stub::min_udaf,
1268 test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition,
1269 };
1270 use arrow::datatypes::{UnionFields, UnionMode};
1271
1272 #[test]
1273 fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
1274 let result = group_window_expr_by_sort_keys(vec![])?;
1275 let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![];
1276 assert_eq!(expected, result);
1277 Ok(())
1278 }
1279
1280 #[test]
1281 fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> {
1282 let max1 = Expr::WindowFunction(WindowFunction::new(
1283 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1284 vec![col("name")],
1285 ));
1286 let max2 = Expr::WindowFunction(WindowFunction::new(
1287 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1288 vec![col("name")],
1289 ));
1290 let min3 = Expr::WindowFunction(WindowFunction::new(
1291 WindowFunctionDefinition::AggregateUDF(min_udaf()),
1292 vec![col("name")],
1293 ));
1294 let sum4 = Expr::WindowFunction(WindowFunction::new(
1295 WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1296 vec![col("age")],
1297 ));
1298 let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1299 let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1300 let key = vec![];
1301 let expected: Vec<(WindowSortKey, Vec<Expr>)> =
1302 vec![(key, vec![max1, max2, min3, sum4])];
1303 assert_eq!(expected, result);
1304 Ok(())
1305 }
1306
1307 #[test]
1308 fn test_group_window_expr_by_sort_keys() -> Result<()> {
1309 let age_asc = Sort::new(col("age"), true, true);
1310 let name_desc = Sort::new(col("name"), false, true);
1311 let created_at_desc = Sort::new(col("created_at"), false, true);
1312 let max1 = Expr::WindowFunction(WindowFunction::new(
1313 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1314 vec![col("name")],
1315 ))
1316 .order_by(vec![age_asc.clone(), name_desc.clone()])
1317 .build()
1318 .unwrap();
1319 let max2 = Expr::WindowFunction(WindowFunction::new(
1320 WindowFunctionDefinition::AggregateUDF(max_udaf()),
1321 vec![col("name")],
1322 ));
1323 let min3 = Expr::WindowFunction(WindowFunction::new(
1324 WindowFunctionDefinition::AggregateUDF(min_udaf()),
1325 vec![col("name")],
1326 ))
1327 .order_by(vec![age_asc.clone(), name_desc.clone()])
1328 .build()
1329 .unwrap();
1330 let sum4 = Expr::WindowFunction(WindowFunction::new(
1331 WindowFunctionDefinition::AggregateUDF(sum_udaf()),
1332 vec![col("age")],
1333 ))
1334 .order_by(vec![
1335 name_desc.clone(),
1336 age_asc.clone(),
1337 created_at_desc.clone(),
1338 ])
1339 .build()
1340 .unwrap();
1341 let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()];
1343 let result = group_window_expr_by_sort_keys(exprs.to_vec())?;
1344
1345 let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)];
1346 let key2 = vec![];
1347 let key3 = vec![
1348 (name_desc, false),
1349 (age_asc, false),
1350 (created_at_desc, false),
1351 ];
1352
1353 let expected: Vec<(WindowSortKey, Vec<Expr>)> = vec![
1354 (key1, vec![max1, min3]),
1355 (key2, vec![max2]),
1356 (key3, vec![sum4]),
1357 ];
1358 assert_eq!(expected, result);
1359 Ok(())
1360 }
1361
1362 #[test]
1363 fn avoid_generate_duplicate_sort_keys() -> Result<()> {
1364 let asc_or_desc = [true, false];
1365 let nulls_first_or_last = [true, false];
1366 let partition_by = &[col("age"), col("name"), col("created_at")];
1367 for asc_ in asc_or_desc {
1368 for nulls_first_ in nulls_first_or_last {
1369 let order_by = &[
1370 Sort {
1371 expr: col("age"),
1372 asc: asc_,
1373 nulls_first: nulls_first_,
1374 },
1375 Sort {
1376 expr: col("name"),
1377 asc: asc_,
1378 nulls_first: nulls_first_,
1379 },
1380 ];
1381
1382 let expected = vec![
1383 (
1384 Sort {
1385 expr: col("age"),
1386 asc: asc_,
1387 nulls_first: nulls_first_,
1388 },
1389 true,
1390 ),
1391 (
1392 Sort {
1393 expr: col("name"),
1394 asc: asc_,
1395 nulls_first: nulls_first_,
1396 },
1397 true,
1398 ),
1399 (
1400 Sort {
1401 expr: col("created_at"),
1402 asc: true,
1403 nulls_first: false,
1404 },
1405 true,
1406 ),
1407 ];
1408 let result = generate_sort_key(partition_by, order_by)?;
1409 assert_eq!(expected, result);
1410 }
1411 }
1412 Ok(())
1413 }
1414
/// Checks how `enumerate_grouping_sets` renders different combinations of
/// plain columns, `CUBE`, `ROLLUP` and explicit `GROUPING SETS`.
///
/// Each case pairs a group-by expression list with the exact string expected
/// after formatting the returned expressions with `expr_vec_fmt!`.
#[test]
fn test_enumerate_grouping_sets() -> Result<()> {
    let columns = vec![col("col1"), col("col2"), col("col3")];
    let plain = col("simple_col");
    let cube_expr = cube(columns.clone());
    let rollup_expr = rollup(columns.clone());
    let grouping_set_expr = grouping_set(vec![columns]);

    let cases = vec![
        // A single expression on its own is rendered unchanged.
        (vec![plain.clone()], "[simple_col]"),
        (vec![cube_expr.clone()], "[CUBE (col1, col2, col3)]"),
        (vec![rollup_expr.clone()], "[ROLLUP (col1, col2, col3)]"),
        // Plain column + CUBE: the column is prefixed to every cube subset.
        (
            vec![plain.clone(), cube_expr.clone()],
            "[GROUPING SETS (\
             (simple_col), \
             (simple_col, col1), \
             (simple_col, col2), \
             (simple_col, col1, col2), \
             (simple_col, col3), \
             (simple_col, col1, col3), \
             (simple_col, col2, col3), \
             (simple_col, col1, col2, col3))]",
        ),
        // Plain column + ROLLUP: the column is prefixed to every rollup prefix.
        (
            vec![plain.clone(), rollup_expr.clone()],
            "[GROUPING SETS (\
             (simple_col), \
             (simple_col, col1), \
             (simple_col, col1, col2), \
             (simple_col, col1, col2, col3))]",
        ),
        // Plain column + explicit grouping set with a single set.
        (
            vec![plain.clone(), grouping_set_expr.clone()],
            "[GROUPING SETS (\
             (simple_col, col1, col2, col3))]",
        ),
        // Plain column + grouping set + ROLLUP: duplicates are kept as-is.
        (
            vec![plain.clone(), grouping_set_expr, rollup_expr.clone()],
            "[GROUPING SETS (\
             (simple_col, col1, col2, col3), \
             (simple_col, col1, col2, col3, col1), \
             (simple_col, col1, col2, col3, col1, col2), \
             (simple_col, col1, col2, col3, col1, col2, col3))]",
        ),
        // Plain column + CUBE + ROLLUP: every cube subset crossed with every
        // rollup prefix (8 x 4 = 32 sets).
        (
            vec![plain, cube_expr, rollup_expr],
            "[GROUPING SETS (\
             (simple_col), \
             (simple_col, col1), \
             (simple_col, col1, col2), \
             (simple_col, col1, col2, col3), \
             (simple_col, col1), \
             (simple_col, col1, col1), \
             (simple_col, col1, col1, col2), \
             (simple_col, col1, col1, col2, col3), \
             (simple_col, col2), \
             (simple_col, col2, col1), \
             (simple_col, col2, col1, col2), \
             (simple_col, col2, col1, col2, col3), \
             (simple_col, col1, col2), \
             (simple_col, col1, col2, col1), \
             (simple_col, col1, col2, col1, col2), \
             (simple_col, col1, col2, col1, col2, col3), \
             (simple_col, col3), \
             (simple_col, col3, col1), \
             (simple_col, col3, col1, col2), \
             (simple_col, col3, col1, col2, col3), \
             (simple_col, col1, col3), \
             (simple_col, col1, col3, col1), \
             (simple_col, col1, col3, col1, col2), \
             (simple_col, col1, col3, col1, col2, col3), \
             (simple_col, col2, col3), \
             (simple_col, col2, col3, col1), \
             (simple_col, col2, col3, col1, col2), \
             (simple_col, col2, col3, col1, col2, col3), \
             (simple_col, col1, col2, col3), \
             (simple_col, col1, col2, col3, col1), \
             (simple_col, col1, col2, col3, col1, col2), \
             (simple_col, col1, col2, col3, col1, col2, col3))]",
        ),
    ];

    // Cases run in declaration order; the first mismatch aborts the test.
    for (group_expr, expected) in cases {
        let sets = enumerate_grouping_sets(group_expr)?;
        let result = format!("[{}]", expr_vec_fmt!(sets));
        assert_eq!(expected, &result);
    }

    Ok(())
}
/// A lone column is not a conjunction, so splitting yields just that column.
#[test]
fn test_split_conjunction() {
    let single = col("a");
    assert_eq!(split_conjunction(&single), vec![&single]);
}
1540
/// `a = 5 AND b` splits into its two operands, in order.
#[test]
fn test_split_conjunction_two() {
    let predicate = col("a").eq(lit(5)).and(col("b"));
    let (lhs, rhs) = (col("a").eq(lit(5)), col("b"));

    let parts = split_conjunction(&predicate);
    assert_eq!(parts, vec![&lhs, &rhs]);
}
1550
/// Splitting `a = 5 AND (b AS the_alias)` yields the un-aliased `b` as the
/// second operand.
#[test]
fn test_split_conjunction_alias() {
    let predicate = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
    let lhs = col("a").eq(lit(5));
    // Expected operand carries no alias.
    let rhs = col("b");

    let parts = split_conjunction(&predicate);
    assert_eq!(parts, vec![&lhs, &rhs]);
}
1560
/// An `OR` expression is not split by `split_conjunction`; it comes back whole.
#[test]
fn test_split_conjunction_or() {
    let either = col("a").eq(lit(5)).or(col("b"));
    assert_eq!(split_conjunction(&either), vec![&either]);
}
1567
/// Splitting a lone column on `AND` returns that column unchanged.
#[test]
fn test_split_binary_owned() {
    let input = col("a");
    let parts = split_binary_owned(input.clone(), Operator::And);
    assert_eq!(parts, vec![input]);
}
1573
/// `a = 5 AND b` split on `AND` yields both operands as owned expressions.
#[test]
fn test_split_binary_owned_two() {
    let predicate = col("a").eq(lit(5)).and(col("b"));
    let expected = vec![col("a").eq(lit(5)), col("b")];
    assert_eq!(split_binary_owned(predicate, Operator::And), expected);
}
1581
/// Splitting on an operator the expression does not use leaves it intact.
#[test]
fn test_split_binary_owned_different_op() {
    // The operands are joined by OR, but the split is requested on AND.
    let either = col("a").eq(lit(5)).or(col("b"));
    let parts = split_binary_owned(either.clone(), Operator::And);
    assert_eq!(parts, vec![either]);
}
1591
/// The owned variant returns a lone column unchanged.
#[test]
fn test_split_conjunction_owned() {
    let lone = col("a");
    let parts = split_conjunction_owned(lone.clone());
    assert_eq!(parts, vec![lone]);
}
1597
/// The owned variant splits `a = 5 AND b` into its two operands.
#[test]
fn test_split_conjunction_owned_two() {
    let predicate = col("a").eq(lit(5)).and(col("b"));
    let expected = vec![col("a").eq(lit(5)), col("b")];
    assert_eq!(split_conjunction_owned(predicate), expected);
}
1605
/// The owned variant returns the aliased operand without its alias.
#[test]
fn test_split_conjunction_owned_alias() {
    let predicate = col("a").eq(lit(5)).and(col("b").alias("the_alias"));
    // Second expected operand is the bare column, with no alias attached.
    let expected = vec![col("a").eq(lit(5)), col("b")];
    assert_eq!(split_conjunction_owned(predicate), expected);
}
1617
/// Combining no expressions with `AND` produces `None`.
#[test]
fn test_conjunction_empty() {
    let combined = conjunction(vec![]);
    assert_eq!(combined, None);
}
1622
/// `conjunction` folds `[a, b, c]` left-associatively.
#[test]
fn test_conjunction() {
    let combined = conjunction(vec![col("a"), col("b"), col("c")]);

    // [a, b, c] -> (a AND b) AND c
    let left_nested = col("a").and(col("b")).and(col("c"));
    assert_eq!(combined, Some(left_nested));

    // ...and specifically not a AND (b AND c)
    let right_nested = col("a").and(col("b").and(col("c")));
    assert_ne!(combined, Some(right_nested));
}
1634
/// Combining no expressions with `OR` produces `None`.
#[test]
fn test_disjunction_empty() {
    let combined = disjunction(vec![]);
    assert_eq!(combined, None);
}
1639
/// `disjunction` folds `[a, b, c]` left-associatively.
#[test]
fn test_disjunction() {
    let combined = disjunction(vec![col("a"), col("b"), col("c")]);

    // [a, b, c] -> (a OR b) OR c
    let left_nested = col("a").or(col("b")).or(col("c"));
    assert_eq!(combined, Some(left_nested));

    // ...and specifically not a OR (b OR c)
    let right_nested = col("a").or(col("b").or(col("c")));
    assert_ne!(combined, Some(right_nested));
}
1651
/// The owned variant leaves an `OR` expression whole.
#[test]
fn test_split_conjunction_owned_or() {
    let either = col("a").eq(lit(5)).or(col("b"));
    let parts = split_conjunction_owned(either.clone());
    assert_eq!(parts, vec![either]);
}
1657
/// Collecting columns from the same `CAST(a AS Float64)` expression twice
/// into one set leaves exactly one entry, the column `a`.
#[test]
fn test_collect_expr() -> Result<()> {
    let mut columns: HashSet<Column> = HashSet::new();

    // Feed an identical cast expression into the accumulator two times.
    for _ in 0..2 {
        let cast_a = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
        expr_to_columns(&cast_a, &mut columns)?;
    }

    assert_eq!(1, columns.len());
    assert!(columns.contains(&Column::from_name("a")));
    Ok(())
}
1673
/// `can_hash` rejects a sparse union type, and also a list whose element
/// type is that union.
#[test]
fn test_can_hash() {
    let fields = UnionFields::from_iter([
        (0, Arc::new(Field::new("A", DataType::Int32, true))),
        (1, Arc::new(Field::new("B", DataType::Float64, true))),
    ]);

    let sparse_union = DataType::Union(fields, UnionMode::Sparse);
    assert!(!can_hash(&sparse_union));

    // Wrap the union as the element type of a list and check again.
    let element = Field::new("my_union", sparse_union, true);
    let list_of_union = DataType::List(Arc::new(element));
    assert!(!can_hash(&list_of_union));
}
1690}