1use std::vec;
21
22use arrow::datatypes::{
23 DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
24};
25use datafusion_common::tree_node::{
26 Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
27};
28use datafusion_common::{
29 exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchemaRef,
30 Diagnostic, HashMap, Result, ScalarValue,
31};
32use datafusion_expr::builder::get_struct_unnested_columns;
33use datafusion_expr::expr::{
34 Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
35};
36use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
37use datafusion_expr::{
38 col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan,
39};
40
41use indexmap::IndexMap;
42use sqlparser::ast::{Ident, Value};
43
44pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
46 expr.clone()
47 .transform_up(|nested_expr| {
48 match nested_expr {
49 Expr::Column(col) => {
50 let (qualifier, field) =
51 plan.schema().qualified_field_from_column(&col)?;
52 Ok(Transformed::yes(Expr::Column(Column::from((
53 qualifier, field,
54 )))))
55 }
56 _ => {
57 Ok(Transformed::no(nested_expr))
59 }
60 }
61 })
62 .data()
63}
64
65pub(crate) fn rebase_expr(
80 expr: &Expr,
81 base_exprs: &[Expr],
82 plan: &LogicalPlan,
83) -> Result<Expr> {
84 expr.clone()
85 .transform_down(|nested_expr| {
86 if base_exprs.contains(&nested_expr) {
87 Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
88 } else {
89 Ok(Transformed::no(nested_expr))
90 }
91 })
92 .data()
93}
94
/// The SQL clause whose columns are being checked against the GROUP BY /
/// aggregate expressions. It selects the prefix of the resulting error
/// message.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
    /// Columns come from the SELECT list.
    Projection,
    /// Columns come from the HAVING clause.
    Having,
    /// Columns come from the QUALIFY clause.
    Qualify,
}
101
102#[derive(Debug, Clone, Copy, PartialEq, Eq)]
103pub(crate) enum CheckColumnsSatisfyExprsPurpose {
104 Aggregate(CheckColumnsMustReferenceAggregatePurpose),
105}
106
107impl CheckColumnsSatisfyExprsPurpose {
108 fn message_prefix(&self) -> &'static str {
109 match self {
110 Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
111 "Column in SELECT must be in GROUP BY or an aggregate function"
112 }
113 Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
114 "Column in HAVING must be in GROUP BY or an aggregate function"
115 }
116 Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
117 "Column in QUALIFY must be in GROUP BY or an aggregate function"
118 }
119 }
120 }
121
122 fn diagnostic_message(&self, expr: &Expr) -> String {
123 format!("'{expr}' must appear in GROUP BY clause because it's not an aggregate expression")
124 }
125}
126
127pub(crate) fn check_columns_satisfy_exprs(
130 columns: &[Expr],
131 exprs: &[Expr],
132 purpose: CheckColumnsSatisfyExprsPurpose,
133) -> Result<()> {
134 columns.iter().try_for_each(|c| match c {
135 Expr::Column(_) => Ok(()),
136 _ => internal_err!("Expr::Column are required"),
137 })?;
138 let column_exprs = find_column_exprs(exprs);
139 for e in &column_exprs {
140 match e {
141 Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
142 for e in exprs {
143 check_column_satisfies_expr(columns, e, purpose)?;
144 }
145 }
146 Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
147 for e in exprs {
148 check_column_satisfies_expr(columns, e, purpose)?;
149 }
150 }
151 Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
152 for exprs in lists_of_exprs {
153 for e in exprs {
154 check_column_satisfies_expr(columns, e, purpose)?;
155 }
156 }
157 }
158 _ => check_column_satisfies_expr(columns, e, purpose)?,
159 }
160 }
161 Ok(())
162}
163
164fn check_column_satisfies_expr(
165 columns: &[Expr],
166 expr: &Expr,
167 purpose: CheckColumnsSatisfyExprsPurpose,
168) -> Result<()> {
169 if !columns.contains(expr) {
170 let diagnostic = Diagnostic::new_error(
171 purpose.diagnostic_message(expr),
172 expr.spans().and_then(|spans| spans.first()),
173 )
174 .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);
175
176 return plan_err!(
177 "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
178 purpose.message_prefix(),
179 expr,
180 expr_vec_fmt!(columns);
181 diagnostic=diagnostic
182 );
183 }
184 Ok(())
185}
186
187pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
190 exprs
191 .iter()
192 .filter_map(|expr| match expr {
193 Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
194 _ => None,
195 })
196 .collect::<HashMap<String, Expr>>()
197}
198
199pub(crate) fn resolve_positions_to_exprs(
204 expr: Expr,
205 select_exprs: &[Expr],
206) -> Result<Expr> {
207 match expr {
208 Expr::Literal(ScalarValue::Int64(Some(position)), _)
211 if position > 0_i64 && position <= select_exprs.len() as i64 =>
212 {
213 let index = (position - 1) as usize;
214 let select_expr = &select_exprs[index];
215 Ok(match select_expr {
216 Expr::Alias(Alias { expr, .. }) => *expr.clone(),
217 _ => select_expr.clone(),
218 })
219 }
220 Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!(
221 "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
222 position, select_exprs.len()
223 ),
224 _ => Ok(expr),
225 }
226}
227
228pub(crate) fn resolve_aliases_to_exprs(
231 expr: Expr,
232 aliases: &HashMap<String, Expr>,
233) -> Result<Expr> {
234 expr.transform_up(|nested_expr| match nested_expr {
235 Expr::Column(c) if c.relation.is_none() => {
236 if let Some(aliased_expr) = aliases.get(&c.name) {
237 Ok(Transformed::yes(aliased_expr.clone()))
238 } else {
239 Ok(Transformed::no(Expr::Column(c)))
240 }
241 }
242 _ => Ok(Transformed::no(nested_expr)),
243 })
244 .data()
245}
246
247pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
250 let all_partition_keys = window_exprs
251 .iter()
252 .map(|expr| match expr {
253 Expr::WindowFunction(window_fun) => {
254 let WindowFunction {
255 params: WindowFunctionParams { partition_by, .. },
256 ..
257 } = window_fun.as_ref();
258 Ok(partition_by)
259 }
260 Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
261 Expr::WindowFunction(window_fun) => {
262 let WindowFunction {
263 params: WindowFunctionParams { partition_by, .. },
264 ..
265 } = window_fun.as_ref();
266 Ok(partition_by)
267 }
268 expr => exec_err!("Impossibly got non-window expr {expr:?}"),
269 },
270 expr => exec_err!("Impossibly got non-window expr {expr:?}"),
271 })
272 .collect::<Result<Vec<_>>>()?;
273 let result = all_partition_keys
274 .iter()
275 .min_by_key(|s| s.len())
276 .ok_or_else(|| exec_datafusion_err!("No window expressions found"))?;
277 Ok(result)
278}
279
280pub(crate) fn make_decimal_type(
283 precision: Option<u64>,
284 scale: Option<u64>,
285) -> Result<DataType> {
286 let (precision, scale) = match (precision, scale) {
288 (Some(p), Some(s)) => (p as u8, s as i8),
289 (Some(p), None) => (p as u8, 0),
290 (None, Some(_)) => {
291 return plan_err!("Cannot specify only scale for decimal data type")
292 }
293 (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
294 };
295
296 if precision == 0
297 || precision > DECIMAL256_MAX_PRECISION
298 || scale.unsigned_abs() > precision
299 {
300 plan_err!(
301 "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
302 )
303 } else if precision > DECIMAL128_MAX_PRECISION
304 && precision <= DECIMAL256_MAX_PRECISION
305 {
306 Ok(DataType::Decimal256(precision, scale))
307 } else {
308 Ok(DataType::Decimal128(precision, scale))
309 }
310}
311
312pub(crate) fn normalize_ident(id: Ident) -> String {
314 match id.quote_style {
315 Some(_) => id.value,
316 None => id.value.to_ascii_lowercase(),
317 }
318}
319
320pub(crate) fn value_to_string(value: &Value) -> Option<String> {
321 match value {
322 Value::SingleQuotedString(s) => Some(s.to_string()),
323 Value::DollarQuotedString(s) => Some(s.to_string()),
324 Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
325 Value::UnicodeStringLiteral(s) => Some(s.to_string()),
326 Value::EscapedStringLiteral(s) => Some(s.to_string()),
327 Value::DoubleQuotedString(_)
328 | Value::NationalStringLiteral(_)
329 | Value::SingleQuotedByteStringLiteral(_)
330 | Value::DoubleQuotedByteStringLiteral(_)
331 | Value::TripleSingleQuotedString(_)
332 | Value::TripleDoubleQuotedString(_)
333 | Value::TripleSingleQuotedByteStringLiteral(_)
334 | Value::TripleDoubleQuotedByteStringLiteral(_)
335 | Value::SingleQuotedRawStringLiteral(_)
336 | Value::DoubleQuotedRawStringLiteral(_)
337 | Value::TripleSingleQuotedRawStringLiteral(_)
338 | Value::TripleDoubleQuotedRawStringLiteral(_)
339 | Value::HexStringLiteral(_)
340 | Value::Null
341 | Value::Placeholder(_) => None,
342 }
343}
344
345pub(crate) fn rewrite_recursive_unnests_bottom_up(
346 input: &LogicalPlan,
347 unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
348 inner_projection_exprs: &mut Vec<Expr>,
349 original_exprs: &[Expr],
350) -> Result<Vec<Expr>> {
351 Ok(original_exprs
352 .iter()
353 .map(|expr| {
354 rewrite_recursive_unnest_bottom_up(
355 input,
356 unnest_placeholder_columns,
357 inner_projection_exprs,
358 expr,
359 )
360 })
361 .collect::<Result<Vec<_>>>()?
362 .into_iter()
363 .flatten()
364 .collect::<Vec<_>>())
365}
366
/// Prefix of the generated column names that stand in for unnested
/// expressions, e.g. `__unnest_placeholder(<expr>)` and
/// `__unnest_placeholder(<expr>,depth=<n>)`.
pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";
368
369struct RecursiveUnnestRewriter<'a> {
374 input_schema: &'a DFSchemaRef,
375 root_expr: &'a Expr,
376 top_most_unnest: Option<Unnest>,
378 consecutive_unnest: Vec<Option<Unnest>>,
379 inner_projection_exprs: &'a mut Vec<Expr>,
380 columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
381 transformed_root_exprs: Option<Vec<Expr>>,
382}
383impl RecursiveUnnestRewriter<'_> {
384 fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
391 self.consecutive_unnest
392 .iter()
393 .rev()
394 .skip_while(|item| item.is_none())
395 .take_while(|item| item.is_some())
396 .to_owned()
397 .cloned()
398 .map(|item| item.unwrap())
399 .collect()
400 }
401
402 fn transform(
403 &mut self,
404 level: usize,
405 alias_name: String,
406 expr_in_unnest: &Expr,
407 struct_allowed: bool,
408 ) -> Result<Vec<Expr>> {
409 let inner_expr_name = expr_in_unnest.schema_name().to_string();
410
411 let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})");
415 let post_unnest_name =
416 format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})");
417 let placeholder_column = Column::from_name(placeholder_name.clone());
420
421 let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?;
422
423 match data_type {
424 DataType::Struct(inner_fields) => {
425 if !struct_allowed {
426 return internal_err!("unnest on struct can only be applied at the root level of select expression");
427 }
428 push_projection_dedupl(
429 self.inner_projection_exprs,
430 expr_in_unnest.clone().alias(placeholder_name.clone()),
431 );
432 self.columns_unnestings
433 .insert(Column::from_name(placeholder_name.clone()), None);
434 Ok(
435 get_struct_unnested_columns(&placeholder_name, &inner_fields)
436 .into_iter()
437 .map(Expr::Column)
438 .collect(),
439 )
440 }
441 DataType::List(_)
442 | DataType::FixedSizeList(_, _)
443 | DataType::LargeList(_) => {
444 push_projection_dedupl(
445 self.inner_projection_exprs,
446 expr_in_unnest.clone().alias(placeholder_name.clone()),
447 );
448
449 let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
450 let list_unnesting = self
451 .columns_unnestings
452 .entry(placeholder_column)
453 .or_insert(Some(vec![]));
454 let unnesting = ColumnUnnestList {
455 output_column: Column::from_name(post_unnest_name),
456 depth: level,
457 };
458 let list_unnestings = list_unnesting.as_mut().unwrap();
459 if !list_unnestings.contains(&unnesting) {
460 list_unnestings.push(unnesting);
461 }
462 Ok(vec![post_unnest_expr])
463 }
464 _ => {
465 internal_err!("unnest on non-list or struct type is not supported")
466 }
467 }
468 }
469}
470
471impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
472 type Node = Expr;
473
474 fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
479 if let Expr::Unnest(ref unnest_expr) = expr {
480 let (data_type, _) =
481 unnest_expr.expr.data_type_and_nullable(self.input_schema)?;
482 self.consecutive_unnest.push(Some(unnest_expr.clone()));
483 if let DataType::Struct(_) = data_type {
493 self.consecutive_unnest.push(None);
494 }
495 if self.top_most_unnest.is_none() {
496 self.top_most_unnest = Some(unnest_expr.clone());
497 }
498
499 Ok(Transformed::no(expr))
500 } else {
501 self.consecutive_unnest.push(None);
502 Ok(Transformed::no(expr))
503 }
504 }
505
506 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
535 if let Expr::Unnest(ref traversing_unnest) = expr {
536 if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
537 self.top_most_unnest = None;
538 }
539 let unnest_stack = self.get_latest_consecutive_unnest();
547
548 if traversing_unnest == unnest_stack.last().unwrap() {
554 let most_inner = unnest_stack.first().unwrap();
555 let inner_expr = most_inner.expr.as_ref();
556 let unnest_recursion = unnest_stack.len();
563 let struct_allowed = (&expr == self.root_expr) && unnest_recursion == 1;
564
565 let mut transformed_exprs = self.transform(
566 unnest_recursion,
567 expr.schema_name().to_string(),
568 inner_expr,
569 struct_allowed,
570 )?;
571 if struct_allowed {
572 self.transformed_root_exprs = Some(transformed_exprs.clone());
573 }
574 return Ok(Transformed::new(
575 transformed_exprs.swap_remove(0),
576 true,
577 TreeNodeRecursion::Continue,
578 ));
579 }
580 } else {
581 self.consecutive_unnest.push(None);
582 }
583
584 if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
589 push_projection_dedupl(self.inner_projection_exprs, expr.clone());
590 }
591
592 Ok(Transformed::no(expr))
593 }
594}
595
596fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
597 let schema_name = expr.schema_name().to_string();
598 if !projection
599 .iter()
600 .any(|e| e.schema_name().to_string() == schema_name)
601 {
602 projection.push(expr);
603 }
604}
605pub(crate) fn rewrite_recursive_unnest_bottom_up(
615 input: &LogicalPlan,
616 unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
617 inner_projection_exprs: &mut Vec<Expr>,
618 original_expr: &Expr,
619) -> Result<Vec<Expr>> {
620 let mut rewriter = RecursiveUnnestRewriter {
621 input_schema: input.schema(),
622 root_expr: original_expr,
623 top_most_unnest: None,
624 consecutive_unnest: vec![],
625 inner_projection_exprs,
626 columns_unnestings: unnest_placeholder_columns,
627 transformed_root_exprs: None,
628 };
629
630 let Transformed {
640 data: transformed_expr,
641 transformed,
642 tnr: _,
643 } = original_expr.clone().rewrite(&mut rewriter)?;
644
645 if !transformed {
646 #[expect(deprecated)]
648 if matches!(&transformed_expr, Expr::Column(_))
649 || matches!(&transformed_expr, Expr::Wildcard { .. })
650 {
651 push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
652 Ok(vec![transformed_expr])
653 } else {
654 let column_name = transformed_expr.schema_name().to_string();
657 push_projection_dedupl(inner_projection_exprs, transformed_expr);
658 Ok(vec![Expr::Column(Column::from_name(column_name))])
659 }
660 } else {
661 if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
662 return Ok(transformed_root_exprs);
663 }
664 Ok(vec![transformed_expr])
665 }
666}
667
668#[cfg(test)]
669mod tests {
670 use std::{ops::Add, sync::Arc};
671
672 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
673 use datafusion_common::{Column, DFSchema, Result};
674 use datafusion_expr::{
675 col, lit, unnest, ColumnUnnestList, EmptyRelation, LogicalPlan,
676 };
677 use datafusion_functions::core::expr_ext::FieldAccessor;
678 use datafusion_functions_aggregate::expr_fn::count;
679
680 use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
681 use indexmap::IndexMap;
682
683 fn column_unnests_eq(
684 l: Vec<&str>,
685 r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
686 ) {
687 let r_formatted: Vec<String> = r
688 .iter()
689 .map(|i| match i.1 {
690 None => format!("{}", i.0),
691 Some(vec) => format!(
692 "{}=>[{}]",
693 i.0,
694 vec.iter()
695 .map(|i| format!("{i}"))
696 .collect::<Vec<String>>()
697 .join(", ")
698 ),
699 })
700 .collect();
701 let l_formatted: Vec<String> = l.iter().map(|i| (*i).to_string()).collect();
702 assert_eq!(l_formatted, r_formatted);
703 }
704
705 #[test]
706 fn test_transform_bottom_unnest_recursive() -> Result<()> {
707 let schema = Schema::new(vec![
708 Field::new(
709 "3d_col",
710 ArrowDataType::List(Arc::new(Field::new(
711 "2d_col",
712 ArrowDataType::List(Arc::new(Field::new(
713 "elements",
714 ArrowDataType::Int64,
715 true,
716 ))),
717 true,
718 ))),
719 true,
720 ),
721 Field::new("i64_col", ArrowDataType::Int64, true),
722 ]);
723
724 let dfschema = DFSchema::try_from(schema)?;
725
726 let input = LogicalPlan::EmptyRelation(EmptyRelation {
727 produce_one_row: false,
728 schema: Arc::new(dfschema),
729 });
730
731 let mut unnest_placeholder_columns = IndexMap::new();
732 let mut inner_projection_exprs = vec![];
733
734 let original_expr = unnest(unnest(col("3d_col")))
736 .add(unnest(unnest(col("3d_col"))))
737 .add(col("i64_col"));
738 let transformed_exprs = rewrite_recursive_unnest_bottom_up(
739 &input,
740 &mut unnest_placeholder_columns,
741 &mut inner_projection_exprs,
742 &original_expr,
743 )?;
744 assert_eq!(
746 transformed_exprs,
747 vec![col("__unnest_placeholder(3d_col,depth=2)")
748 .alias("UNNEST(UNNEST(3d_col))")
749 .add(
750 col("__unnest_placeholder(3d_col,depth=2)")
751 .alias("UNNEST(UNNEST(3d_col))")
752 )
753 .add(col("i64_col"))]
754 );
755 column_unnests_eq(
756 vec![
757 "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
758 ],
759 &unnest_placeholder_columns,
760 );
761
762 assert_eq!(
765 inner_projection_exprs,
766 vec![
767 col("3d_col").alias("__unnest_placeholder(3d_col)"),
768 col("i64_col")
769 ]
770 );
771
772 let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
774 let transformed_exprs = rewrite_recursive_unnest_bottom_up(
775 &input,
776 &mut unnest_placeholder_columns,
777 &mut inner_projection_exprs,
778 &original_expr_2,
779 )?;
780
781 assert_eq!(
782 transformed_exprs,
783 vec![
784 (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
785 .alias("2d_col")
786 ]
787 );
788 column_unnests_eq(
789 vec!["__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]"],
790 &unnest_placeholder_columns,
791 );
792 assert_eq!(
795 inner_projection_exprs,
796 vec![
797 col("3d_col").alias("__unnest_placeholder(3d_col)"),
798 col("i64_col")
799 ]
800 );
801
802 Ok(())
803 }
804
805 #[test]
806 fn test_transform_bottom_unnest() -> Result<()> {
807 let schema = Schema::new(vec![
808 Field::new(
809 "struct_col",
810 ArrowDataType::Struct(Fields::from(vec![
811 Field::new("field1", ArrowDataType::Int32, false),
812 Field::new("field2", ArrowDataType::Int32, false),
813 ])),
814 false,
815 ),
816 Field::new(
817 "array_col",
818 ArrowDataType::List(Arc::new(Field::new_list_field(
819 ArrowDataType::Int64,
820 true,
821 ))),
822 true,
823 ),
824 Field::new("int_col", ArrowDataType::Int32, false),
825 ]);
826
827 let dfschema = DFSchema::try_from(schema)?;
828
829 let input = LogicalPlan::EmptyRelation(EmptyRelation {
830 produce_one_row: false,
831 schema: Arc::new(dfschema),
832 });
833
834 let mut unnest_placeholder_columns = IndexMap::new();
835 let mut inner_projection_exprs = vec![];
836
837 let original_expr = unnest(col("struct_col"));
839 let transformed_exprs = rewrite_recursive_unnest_bottom_up(
840 &input,
841 &mut unnest_placeholder_columns,
842 &mut inner_projection_exprs,
843 &original_expr,
844 )?;
845 assert_eq!(
846 transformed_exprs,
847 vec![
848 col("__unnest_placeholder(struct_col).field1"),
849 col("__unnest_placeholder(struct_col).field2"),
850 ]
851 );
852 column_unnests_eq(
853 vec!["__unnest_placeholder(struct_col)"],
854 &unnest_placeholder_columns,
855 );
856 assert_eq!(
859 inner_projection_exprs,
860 vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
861 );
862
863 let original_expr = unnest(col("array_col")).add(lit(1i64));
865 let transformed_exprs = rewrite_recursive_unnest_bottom_up(
866 &input,
867 &mut unnest_placeholder_columns,
868 &mut inner_projection_exprs,
869 &original_expr,
870 )?;
871 column_unnests_eq(
872 vec![
873 "__unnest_placeholder(struct_col)",
874 "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
875 ],
876 &unnest_placeholder_columns,
877 );
878 assert_eq!(
880 transformed_exprs,
881 vec![col("__unnest_placeholder(array_col,depth=1)")
882 .alias("UNNEST(array_col)")
883 .add(lit(1i64))]
884 );
885
886 assert_eq!(
890 inner_projection_exprs,
891 vec![
892 col("struct_col").alias("__unnest_placeholder(struct_col)"),
893 col("array_col").alias("__unnest_placeholder(array_col)")
894 ]
895 );
896
897 Ok(())
898 }
899
900 #[test]
902 fn test_transform_non_consecutive_unnests() -> Result<()> {
903 let schema = Schema::new(vec![
906 Field::new(
907 "struct_list",
908 ArrowDataType::List(Arc::new(Field::new(
909 "element",
910 ArrowDataType::Struct(Fields::from(vec![
911 Field::new(
912 "subfield1",
914 ArrowDataType::List(Arc::new(Field::new(
915 "i64_element",
916 ArrowDataType::Int64,
917 true,
918 ))),
919 true,
920 ),
921 Field::new(
922 "subfield2",
924 ArrowDataType::List(Arc::new(Field::new(
925 "utf8_element",
926 ArrowDataType::Utf8,
927 true,
928 ))),
929 true,
930 ),
931 ])),
932 true,
933 ))),
934 true,
935 ),
936 Field::new("int_col", ArrowDataType::Int32, false),
937 ]);
938
939 let dfschema = DFSchema::try_from(schema)?;
940
941 let input = LogicalPlan::EmptyRelation(EmptyRelation {
942 produce_one_row: false,
943 schema: Arc::new(dfschema),
944 });
945
946 let mut unnest_placeholder_columns = IndexMap::new();
947 let mut inner_projection_exprs = vec![];
948
949 let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
951 let transformed_exprs = rewrite_recursive_unnest_bottom_up(
952 &input,
953 &mut unnest_placeholder_columns,
954 &mut inner_projection_exprs,
955 &select_expr1,
956 )?;
957 assert_eq!(
959 transformed_exprs,
960 vec![unnest(
961 col("__unnest_placeholder(struct_list,depth=1)")
962 .alias("UNNEST(struct_list)")
963 .field("subfield1")
964 )]
965 );
966
967 column_unnests_eq(
968 vec![
969 "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
970 ],
971 &unnest_placeholder_columns,
972 );
973
974 assert_eq!(
975 inner_projection_exprs,
976 vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
977 );
978
979 let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
981 let transformed_exprs = rewrite_recursive_unnest_bottom_up(
982 &input,
983 &mut unnest_placeholder_columns,
984 &mut inner_projection_exprs,
985 &select_expr2,
986 )?;
987 assert_eq!(
989 transformed_exprs,
990 vec![unnest(
991 col("__unnest_placeholder(struct_list,depth=1)")
992 .alias("UNNEST(struct_list)")
993 .field("subfield2")
994 )]
995 );
996
997 column_unnests_eq(
1000 vec![
1001 "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
1002 ],
1003 &unnest_placeholder_columns,
1004 );
1005
1006 assert_eq!(
1007 inner_projection_exprs,
1008 vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
1009 );
1010
1011 Ok(())
1012 }
1013
1014 #[test]
1015 fn test_resolve_positions_to_exprs() -> Result<()> {
1016 let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];
1017
1018 let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
1020 assert_eq!(resolved, col("c1"));
1021
1022 let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
1024 assert!(resolved.is_err_and(|e| e.message().contains(
1025 "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
1026 )));
1027
1028 let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
1029 assert!(resolved.is_err_and(|e| e.message().contains(
1030 "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
1031 )));
1032
1033 let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
1035 assert_eq!(resolved, lit("text"));
1036
1037 let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
1038 assert_eq!(resolved, col("fake"));
1039
1040 Ok(())
1041 }
1042}