1use std::vec;
21
22use arrow::datatypes::{
23 DECIMAL_DEFAULT_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType,
24};
25use datafusion_common::tree_node::{
26 Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
27};
28use datafusion_common::{
29 Column, DFSchemaRef, Diagnostic, HashMap, Result, ScalarValue,
30 assert_or_internal_err, exec_datafusion_err, exec_err, internal_err, plan_err,
31};
32use datafusion_expr::builder::get_struct_unnested_columns;
33use datafusion_expr::expr::{
34 Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
35};
36use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
37use datafusion_expr::{
38 ColumnUnnestList, Expr, ExprSchemable, LogicalPlan, col, expr_vec_fmt,
39};
40
41use indexmap::IndexMap;
42use sqlparser::ast::{Ident, Value};
43
44pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
46 expr.clone()
47 .transform_up(|nested_expr| {
48 match nested_expr {
49 Expr::Column(col) => {
50 let (qualifier, field) =
51 plan.schema().qualified_field_from_column(&col)?;
52 Ok(Transformed::yes(Expr::Column(Column::from((
53 qualifier, field,
54 )))))
55 }
56 _ => {
57 Ok(Transformed::no(nested_expr))
59 }
60 }
61 })
62 .data()
63}
64
65pub(crate) fn rebase_expr(
80 expr: &Expr,
81 base_exprs: &[Expr],
82 plan: &LogicalPlan,
83) -> Result<Expr> {
84 expr.clone()
85 .transform_down(|nested_expr| {
86 if base_exprs.contains(&nested_expr) {
87 Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?))
88 } else {
89 Ok(Transformed::no(nested_expr))
90 }
91 })
92 .data()
93}
94
/// Names the SQL clause whose columns are being validated against the
/// GROUP BY / aggregate expressions.
///
/// The variant only affects which clause name appears in the error message
/// prefix produced for a failed check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsMustReferenceAggregatePurpose {
    /// The check is for the SELECT list.
    Projection,
    /// The check is for the HAVING clause.
    Having,
    /// The check is for the QUALIFY clause.
    Qualify,
    /// The check is for the ORDER BY clause.
    OrderBy,
}
102
/// Describes why `check_columns_satisfy_exprs` is being invoked, so that a
/// failure can be reported with a message suited to the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CheckColumnsSatisfyExprsPurpose {
    /// Validation of an aggregate query; the payload names the clause
    /// being checked.
    Aggregate(CheckColumnsMustReferenceAggregatePurpose),
}
107
108impl CheckColumnsSatisfyExprsPurpose {
109 fn message_prefix(&self) -> &'static str {
110 match self {
111 Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Projection) => {
112 "Column in SELECT must be in GROUP BY or an aggregate function"
113 }
114 Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Having) => {
115 "Column in HAVING must be in GROUP BY or an aggregate function"
116 }
117 Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::Qualify) => {
118 "Column in QUALIFY must be in GROUP BY or an aggregate function"
119 }
120 Self::Aggregate(CheckColumnsMustReferenceAggregatePurpose::OrderBy) => {
121 "Column in ORDER BY must be in GROUP BY or an aggregate function"
122 }
123 }
124 }
125
126 fn diagnostic_message(&self, expr: &Expr) -> String {
127 format!(
128 "'{expr}' must appear in GROUP BY clause because it's not an aggregate expression"
129 )
130 }
131}
132
133pub(crate) fn check_columns_satisfy_exprs(
136 columns: &[Expr],
137 exprs: &[Expr],
138 purpose: CheckColumnsSatisfyExprsPurpose,
139) -> Result<()> {
140 columns.iter().try_for_each(|c| match c {
141 Expr::Column(_) => Ok(()),
142 _ => internal_err!("Expr::Column are required"),
143 })?;
144 let column_exprs = find_column_exprs(exprs);
145 for e in &column_exprs {
146 match e {
147 Expr::GroupingSet(GroupingSet::Rollup(exprs)) => {
148 for e in exprs {
149 check_column_satisfies_expr(columns, e, purpose)?;
150 }
151 }
152 Expr::GroupingSet(GroupingSet::Cube(exprs)) => {
153 for e in exprs {
154 check_column_satisfies_expr(columns, e, purpose)?;
155 }
156 }
157 Expr::GroupingSet(GroupingSet::GroupingSets(lists_of_exprs)) => {
158 for exprs in lists_of_exprs {
159 for e in exprs {
160 check_column_satisfies_expr(columns, e, purpose)?;
161 }
162 }
163 }
164 _ => check_column_satisfies_expr(columns, e, purpose)?,
165 }
166 }
167 Ok(())
168}
169
170fn check_column_satisfies_expr(
171 columns: &[Expr],
172 expr: &Expr,
173 purpose: CheckColumnsSatisfyExprsPurpose,
174) -> Result<()> {
175 if !columns.contains(expr) {
176 let diagnostic = Diagnostic::new_error(
177 purpose.diagnostic_message(expr),
178 expr.spans().and_then(|spans| spans.first()),
179 )
180 .with_help(format!("Either add '{expr}' to GROUP BY clause, or use an aggregate function like ANY_VALUE({expr})"), None);
181
182 return plan_err!(
183 "{}: While expanding wildcard, column \"{}\" must appear in the GROUP BY clause or must be part of an aggregate function, currently only \"{}\" appears in the SELECT clause satisfies this requirement",
184 purpose.message_prefix(),
185 expr,
186 expr_vec_fmt!(columns);
187 diagnostic=diagnostic
188 );
189 }
190 Ok(())
191}
192
193pub(crate) fn extract_aliases(exprs: &[Expr]) -> HashMap<String, Expr> {
196 exprs
197 .iter()
198 .filter_map(|expr| match expr {
199 Expr::Alias(Alias { expr, name, .. }) => Some((name.clone(), *expr.clone())),
200 _ => None,
201 })
202 .collect::<HashMap<String, Expr>>()
203}
204
205pub(crate) fn resolve_positions_to_exprs(
210 expr: Expr,
211 select_exprs: &[Expr],
212) -> Result<Expr> {
213 match expr {
214 Expr::Literal(ScalarValue::Int64(Some(position)), _)
217 if position > 0_i64 && position <= select_exprs.len() as i64 =>
218 {
219 let index = (position - 1) as usize;
220 let select_expr = &select_exprs[index];
221 Ok(match select_expr {
222 Expr::Alias(Alias { expr, .. }) => *expr.clone(),
223 _ => select_expr.clone(),
224 })
225 }
226 Expr::Literal(ScalarValue::Int64(Some(position)), _) => plan_err!(
227 "Cannot find column with position {} in SELECT clause. Valid columns: 1 to {}",
228 position,
229 select_exprs.len()
230 ),
231 _ => Ok(expr),
232 }
233}
234
235pub(crate) fn resolve_aliases_to_exprs(
238 expr: Expr,
239 aliases: &HashMap<String, Expr>,
240) -> Result<Expr> {
241 expr.transform_up(|nested_expr| match nested_expr {
242 Expr::Column(c) if c.relation.is_none() => {
243 if let Some(aliased_expr) = aliases.get(&c.name) {
244 Ok(Transformed::yes(aliased_expr.clone()))
245 } else {
246 Ok(Transformed::no(Expr::Column(c)))
247 }
248 }
249 _ => Ok(Transformed::no(nested_expr)),
250 })
251 .data()
252}
253
254pub fn window_expr_common_partition_keys(window_exprs: &[Expr]) -> Result<&[Expr]> {
257 let all_partition_keys = window_exprs
258 .iter()
259 .map(|expr| match expr {
260 Expr::WindowFunction(window_fun) => {
261 let WindowFunction {
262 params: WindowFunctionParams { partition_by, .. },
263 ..
264 } = window_fun.as_ref();
265 Ok(partition_by)
266 }
267 Expr::Alias(Alias { expr, .. }) => match expr.as_ref() {
268 Expr::WindowFunction(window_fun) => {
269 let WindowFunction {
270 params: WindowFunctionParams { partition_by, .. },
271 ..
272 } = window_fun.as_ref();
273 Ok(partition_by)
274 }
275 expr => exec_err!("Impossibly got non-window expr {expr:?}"),
276 },
277 expr => exec_err!("Impossibly got non-window expr {expr:?}"),
278 })
279 .collect::<Result<Vec<_>>>()?;
280 let result = all_partition_keys
281 .iter()
282 .min_by_key(|s| s.len())
283 .ok_or_else(|| exec_datafusion_err!("No window expressions found"))?;
284 Ok(result)
285}
286
287pub(crate) fn make_decimal_type(
290 precision: Option<u64>,
291 scale: Option<u64>,
292) -> Result<DataType> {
293 let (precision, scale) = match (precision, scale) {
295 (Some(p), Some(s)) => (p as u8, s as i8),
296 (Some(p), None) => (p as u8, 0),
297 (None, Some(_)) => {
298 return plan_err!("Cannot specify only scale for decimal data type");
299 }
300 (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE),
301 };
302
303 if precision == 0
304 || precision > DECIMAL256_MAX_PRECISION
305 || scale.unsigned_abs() > precision
306 {
307 plan_err!(
308 "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`."
309 )
310 } else if precision > DECIMAL128_MAX_PRECISION
311 && precision <= DECIMAL256_MAX_PRECISION
312 {
313 Ok(DataType::Decimal256(precision, scale))
314 } else {
315 Ok(DataType::Decimal128(precision, scale))
316 }
317}
318
319pub(crate) fn normalize_ident(id: Ident) -> String {
321 match id.quote_style {
322 Some(_) => id.value,
323 None => id.value.to_ascii_lowercase(),
324 }
325}
326
327pub(crate) fn value_to_string(value: &Value) -> Option<String> {
328 match value {
329 Value::SingleQuotedString(s) => Some(s.to_string()),
330 Value::DollarQuotedString(s) => Some(s.to_string()),
331 Value::Number(_, _) | Value::Boolean(_) => Some(value.to_string()),
332 Value::UnicodeStringLiteral(s) => Some(s.to_string()),
333 Value::EscapedStringLiteral(s) => Some(s.to_string()),
334 Value::QuoteDelimitedStringLiteral(s)
335 | Value::NationalQuoteDelimitedStringLiteral(s) => Some(s.value.to_string()),
336 Value::DoubleQuotedString(_)
337 | Value::NationalStringLiteral(_)
338 | Value::SingleQuotedByteStringLiteral(_)
339 | Value::DoubleQuotedByteStringLiteral(_)
340 | Value::TripleSingleQuotedString(_)
341 | Value::TripleDoubleQuotedString(_)
342 | Value::TripleSingleQuotedByteStringLiteral(_)
343 | Value::TripleDoubleQuotedByteStringLiteral(_)
344 | Value::SingleQuotedRawStringLiteral(_)
345 | Value::DoubleQuotedRawStringLiteral(_)
346 | Value::TripleSingleQuotedRawStringLiteral(_)
347 | Value::TripleDoubleQuotedRawStringLiteral(_)
348 | Value::HexStringLiteral(_)
349 | Value::Null
350 | Value::Placeholder(_) => None,
351 }
352}
353
354pub(crate) fn rewrite_recursive_unnests_bottom_up(
355 input: &LogicalPlan,
356 unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
357 inner_projection_exprs: &mut Vec<Expr>,
358 original_exprs: &[Expr],
359) -> Result<Vec<Expr>> {
360 Ok(original_exprs
361 .iter()
362 .map(|expr| {
363 rewrite_recursive_unnest_bottom_up(
364 input,
365 unnest_placeholder_columns,
366 inner_projection_exprs,
367 expr,
368 )
369 })
370 .collect::<Result<Vec<_>>>()?
371 .into_iter()
372 .flatten()
373 .collect::<Vec<_>>())
374}
375
/// Prefix of the synthetic column names generated while rewriting unnest
/// expressions, e.g. `__unnest_placeholder(<expr>)` and
/// `__unnest_placeholder(<expr>,depth=<n>)`.
pub const UNNEST_PLACEHOLDER: &str = "__unnest_placeholder";
377
/// State carried while rewriting the unnest expressions found inside a
/// single root expression.
struct RecursiveUnnestRewriter<'a> {
    /// Schema used to resolve the field (and thus data type) of the
    /// expressions being unnested.
    input_schema: &'a DFSchemaRef,
    /// The expression the rewrite started from; used to decide whether a
    /// struct unnest sits at the root.
    root_expr: &'a Expr,
    /// The first unnest met on the way down while none was recorded; reset
    /// to `None` when that same unnest is visited on the way up.
    top_most_unnest: Option<Unnest>,
    /// Stack of markers pushed during traversal: `Some(unnest)` for unnest
    /// nodes and `None` for everything else.
    consecutive_unnest: Vec<Option<Unnest>>,
    /// Expressions to be projected beneath the unnest; entries are added
    /// without duplicating schema names.
    inner_projection_exprs: &'a mut Vec<Expr>,
    /// Placeholder column -> `None` for a struct unnest, or the list of
    /// unnestings (output column and depth) for a list unnest.
    columns_unnestings: &'a mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
    /// Set when a struct unnest allowed at the root expands into more than
    /// one expression; holds all of them.
    transformed_root_exprs: Option<Vec<Expr>>,
}
392impl RecursiveUnnestRewriter<'_> {
393 fn get_latest_consecutive_unnest(&self) -> Vec<Unnest> {
400 self.consecutive_unnest
401 .iter()
402 .rev()
403 .skip_while(|item| item.is_none())
404 .take_while(|item| item.is_some())
405 .to_owned()
406 .cloned()
407 .map(|item| item.unwrap())
408 .collect()
409 }
410
411 fn is_at_struct_allowed_root(&self, expr: &Expr) -> bool {
419 if expr == self.root_expr {
420 return true;
421 }
422 if let Expr::Alias(Alias { expr: inner, .. }) = self.root_expr {
424 return inner.as_ref() == expr;
425 }
426 false
427 }
428
429 fn transform(
430 &mut self,
431 level: usize,
432 alias_name: String,
433 expr_in_unnest: &Expr,
434 struct_allowed: bool,
435 ) -> Result<Vec<Expr>> {
436 let inner_expr_name = expr_in_unnest.schema_name().to_string();
437
438 let placeholder_name = format!("{UNNEST_PLACEHOLDER}({inner_expr_name})");
442 let post_unnest_name =
443 format!("{UNNEST_PLACEHOLDER}({inner_expr_name},depth={level})");
444 let placeholder_column = Column::from_name(placeholder_name.clone());
447 let field = expr_in_unnest.to_field(self.input_schema)?.1;
448 let data_type = field.data_type();
449
450 match data_type {
451 DataType::Struct(inner_fields) => {
452 assert_or_internal_err!(
453 struct_allowed,
454 "unnest on struct can only be applied at the root level of select expression"
455 );
456 push_projection_dedupl(
457 self.inner_projection_exprs,
458 expr_in_unnest.clone().alias(placeholder_name.clone()),
459 );
460 self.columns_unnestings
461 .insert(Column::from_name(placeholder_name.clone()), None);
462 Ok(get_struct_unnested_columns(&placeholder_name, inner_fields)
463 .into_iter()
464 .map(Expr::Column)
465 .collect())
466 }
467 DataType::List(_)
468 | DataType::FixedSizeList(_, _)
469 | DataType::LargeList(_) => {
470 push_projection_dedupl(
471 self.inner_projection_exprs,
472 expr_in_unnest.clone().alias(placeholder_name.clone()),
473 );
474
475 let post_unnest_expr = col(post_unnest_name.clone()).alias(alias_name);
476 let list_unnesting = self
477 .columns_unnestings
478 .entry(placeholder_column)
479 .or_insert(Some(vec![]));
480 let unnesting = ColumnUnnestList {
481 output_column: Column::from_name(post_unnest_name),
482 depth: level,
483 };
484 let list_unnestings = list_unnesting.as_mut().unwrap();
485 if !list_unnestings.contains(&unnesting) {
486 list_unnestings.push(unnesting);
487 }
488 Ok(vec![post_unnest_expr])
489 }
490 _ => {
491 internal_err!("unnest on non-list or struct type is not supported")
492 }
493 }
494 }
495}
496
497impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
498 type Node = Expr;
499
500 fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
505 if let Expr::Unnest(ref unnest_expr) = expr {
506 let field = unnest_expr.expr.to_field(self.input_schema)?.1;
507 let data_type = field.data_type();
508 self.consecutive_unnest.push(Some(unnest_expr.clone()));
509 if let DataType::Struct(_) = data_type {
519 self.consecutive_unnest.push(None);
520 }
521 if self.top_most_unnest.is_none() {
522 self.top_most_unnest = Some(unnest_expr.clone());
523 }
524
525 Ok(Transformed::no(expr))
526 } else {
527 self.consecutive_unnest.push(None);
528 Ok(Transformed::no(expr))
529 }
530 }
531
532 fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
561 if let Expr::Unnest(ref traversing_unnest) = expr {
562 if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
563 self.top_most_unnest = None;
564 }
565 let unnest_stack = self.get_latest_consecutive_unnest();
573
574 if traversing_unnest == unnest_stack.last().unwrap() {
580 let most_inner = unnest_stack.first().unwrap();
581 let inner_expr = most_inner.expr.as_ref();
582 let unnest_recursion = unnest_stack.len();
589 let struct_allowed =
590 self.is_at_struct_allowed_root(&expr) && unnest_recursion == 1;
591
592 let mut transformed_exprs = self.transform(
593 unnest_recursion,
594 expr.schema_name().to_string(),
595 inner_expr,
596 struct_allowed,
597 )?;
598 if struct_allowed && transformed_exprs.len() > 1 {
601 self.transformed_root_exprs = Some(transformed_exprs.clone());
602 }
603 return Ok(Transformed::new(
604 transformed_exprs.swap_remove(0),
605 true,
606 TreeNodeRecursion::Continue,
607 ));
608 }
609 } else {
610 self.consecutive_unnest.push(None);
611 }
612
613 if matches!(&expr, Expr::Column(_)) && self.top_most_unnest.is_none() {
618 push_projection_dedupl(self.inner_projection_exprs, expr.clone());
619 }
620
621 Ok(Transformed::no(expr))
622 }
623}
624
625fn push_projection_dedupl(projection: &mut Vec<Expr>, expr: Expr) {
626 let schema_name = expr.schema_name().to_string();
627 if !projection
628 .iter()
629 .any(|e| e.schema_name().to_string() == schema_name)
630 {
631 projection.push(expr);
632 }
633}
634pub(crate) fn rewrite_recursive_unnest_bottom_up(
644 input: &LogicalPlan,
645 unnest_placeholder_columns: &mut IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
646 inner_projection_exprs: &mut Vec<Expr>,
647 original_expr: &Expr,
648) -> Result<Vec<Expr>> {
649 let mut rewriter = RecursiveUnnestRewriter {
650 input_schema: input.schema(),
651 root_expr: original_expr,
652 top_most_unnest: None,
653 consecutive_unnest: vec![],
654 inner_projection_exprs,
655 columns_unnestings: unnest_placeholder_columns,
656 transformed_root_exprs: None,
657 };
658
659 let Transformed {
669 data: transformed_expr,
670 transformed,
671 tnr: _,
672 } = original_expr.clone().rewrite(&mut rewriter)?;
673
674 if !transformed {
675 #[expect(deprecated)]
677 if matches!(&transformed_expr, Expr::Column(_))
678 || matches!(&transformed_expr, Expr::Wildcard { .. })
679 {
680 push_projection_dedupl(inner_projection_exprs, transformed_expr.clone());
681 Ok(vec![transformed_expr])
682 } else {
683 let column_name = transformed_expr.schema_name().to_string();
686 push_projection_dedupl(inner_projection_exprs, transformed_expr);
687 Ok(vec![Expr::Column(Column::from_name(column_name))])
688 }
689 } else {
690 if let Some(transformed_root_exprs) = rewriter.transformed_root_exprs {
691 return Ok(transformed_root_exprs);
692 }
693 Ok(vec![transformed_expr])
694 }
695}
696
697#[cfg(test)]
698mod tests {
699 use std::{ops::Add, sync::Arc};
700
701 use arrow::datatypes::{DataType as ArrowDataType, Field, Fields, Schema};
702 use datafusion_common::{Column, DFSchema, Result};
703 use datafusion_expr::{
704 ColumnUnnestList, EmptyRelation, LogicalPlan, col, lit, unnest,
705 };
706 use datafusion_functions::core::expr_ext::FieldAccessor;
707 use datafusion_functions_aggregate::expr_fn::count;
708
709 use crate::utils::{resolve_positions_to_exprs, rewrite_recursive_unnest_bottom_up};
710 use indexmap::IndexMap;
711
712 fn column_unnests_eq(
713 l: Vec<&str>,
714 r: &IndexMap<Column, Option<Vec<ColumnUnnestList>>>,
715 ) {
716 let r_formatted: Vec<String> = r
717 .iter()
718 .map(|i| match i.1 {
719 None => format!("{}", i.0),
720 Some(vec) => format!(
721 "{}=>[{}]",
722 i.0,
723 vec.iter()
724 .map(|i| format!("{i}"))
725 .collect::<Vec<String>>()
726 .join(", ")
727 ),
728 })
729 .collect();
730 let l_formatted: Vec<String> = l.iter().map(|i| (*i).to_string()).collect();
731 assert_eq!(l_formatted, r_formatted);
732 }
733
    /// Exercises nested list unnests on a list-of-list column, and checks
    /// that state accumulates correctly across two successive rewrites.
    #[test]
    fn test_transform_bottom_unnest_recursive() -> Result<()> {
        // Schema: `3d_col` is List<List<Int64>>, `i64_col` is a plain Int64.
        let schema = Schema::new(vec![
            Field::new(
                "3d_col",
                ArrowDataType::List(Arc::new(Field::new(
                    "2d_col",
                    ArrowDataType::List(Arc::new(Field::new(
                        "elements",
                        ArrowDataType::Int64,
                        true,
                    ))),
                    true,
                ))),
                true,
            ),
            Field::new("i64_col", ArrowDataType::Int64, true),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        // Shared state, reused by both rewrites below.
        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(unnest(3d_col)) + unnest(unnest(3d_col)) + i64_col
        let original_expr = unnest(unnest(col("3d_col")))
            .add(unnest(unnest(col("3d_col"))))
            .add(col("i64_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // Both double unnests become the same depth=2 placeholder column,
        // aliased to the original unnest's name.
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(3d_col,depth=2)")
                    .alias("UNNEST(UNNEST(3d_col))")
                    .add(
                        col("__unnest_placeholder(3d_col,depth=2)")
                            .alias("UNNEST(UNNEST(3d_col))")
                    )
                    .add(col("i64_col"))
            ]
        );
        // The identical unnesting is registered only once.
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2]",
            ],
            &unnest_placeholder_columns,
        );

        // The inner projection holds the aliased source column once, plus
        // the column referenced outside of any unnest.
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        // unnest(3d_col) AS 2d_col, rewritten with the same shared state.
        let original_expr_2 = unnest(col("3d_col")).alias("2d_col");
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr_2,
        )?;

        assert_eq!(
            transformed_exprs,
            vec![
                (col("__unnest_placeholder(3d_col,depth=1)").alias("UNNEST(3d_col)"))
                    .alias("2d_col")
            ]
        );
        // A depth=1 unnesting is appended to the existing placeholder entry.
        column_unnests_eq(
            vec![
                "__unnest_placeholder(3d_col)=>[__unnest_placeholder(3d_col,depth=2)|depth=2, __unnest_placeholder(3d_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // The inner projection is unchanged: the source column was already
        // present under the same placeholder name.
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("3d_col").alias("__unnest_placeholder(3d_col)"),
                col("i64_col")
            ]
        );

        Ok(())
    }
837
    /// Exercises a root-level struct unnest followed by a list unnest used
    /// inside an arithmetic expression, sharing state between the two.
    #[test]
    fn test_transform_bottom_unnest() -> Result<()> {
        // Schema: a two-field struct column, a List<Int64> column and an
        // Int32 column.
        let schema = Schema::new(vec![
            Field::new(
                "struct_col",
                ArrowDataType::Struct(Fields::from(vec![
                    Field::new("field1", ArrowDataType::Int32, false),
                    Field::new("field2", ArrowDataType::Int32, false),
                ])),
                false,
            ),
            Field::new(
                "array_col",
                ArrowDataType::List(Arc::new(Field::new_list_field(
                    ArrowDataType::Int64,
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        // Shared state, reused by both rewrites below.
        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(struct_col)
        let original_expr = unnest(col("struct_col"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // The struct unnest expands into one column per struct field.
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(struct_col).field1"),
                col("__unnest_placeholder(struct_col).field2"),
            ]
        );
        // The struct placeholder is registered without a list of unnestings.
        column_unnests_eq(
            vec!["__unnest_placeholder(struct_col)"],
            &unnest_placeholder_columns,
        );
        // The source column is projected under the placeholder alias.
        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_col").alias("__unnest_placeholder(struct_col)"),]
        );

        // unnest(array_col) + 1
        let original_expr = unnest(col("array_col")).add(lit(1i64));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &original_expr,
        )?;
        // The list placeholder is added after the struct one.
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_col)",
                "__unnest_placeholder(array_col)=>[__unnest_placeholder(array_col,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );
        // Only the unnest is replaced; the surrounding addition is kept.
        assert_eq!(
            transformed_exprs,
            vec![
                col("__unnest_placeholder(array_col,depth=1)")
                    .alias("UNNEST(array_col)")
                    .add(lit(1i64))
            ]
        );

        // The inner projection now holds both aliased source columns.
        assert_eq!(
            inner_projection_exprs,
            vec![
                col("struct_col").alias("__unnest_placeholder(struct_col)"),
                col("array_col").alias("__unnest_placeholder(array_col)")
            ]
        );

        Ok(())
    }
934
    /// Exercises two unnests separated by a field access, i.e. unnests that
    /// are nested but not directly consecutive.
    #[test]
    fn test_transform_non_consecutive_unnests() -> Result<()> {
        // Schema: `struct_list` is a list of structs whose two fields are
        // themselves lists (of Int64 and Utf8), plus an Int32 column.
        let schema = Schema::new(vec![
            Field::new(
                "struct_list",
                ArrowDataType::List(Arc::new(Field::new(
                    "element",
                    ArrowDataType::Struct(Fields::from(vec![
                        Field::new(
                            "subfield1",
                            ArrowDataType::List(Arc::new(Field::new(
                                "i64_element",
                                ArrowDataType::Int64,
                                true,
                            ))),
                            true,
                        ),
                        Field::new(
                            "subfield2",
                            ArrowDataType::List(Arc::new(Field::new(
                                "utf8_element",
                                ArrowDataType::Utf8,
                                true,
                            ))),
                            true,
                        ),
                    ])),
                    true,
                ))),
                true,
            ),
            Field::new("int_col", ArrowDataType::Int32, false),
        ]);

        let dfschema = DFSchema::try_from(schema)?;

        let input = LogicalPlan::EmptyRelation(EmptyRelation {
            produce_one_row: false,
            schema: Arc::new(dfschema),
        });

        // Shared state, reused by both rewrites below.
        let mut unnest_placeholder_columns = IndexMap::new();
        let mut inner_projection_exprs = vec![];

        // unnest(unnest(struct_list)['subfield1'])
        let select_expr1 = unnest(unnest(col("struct_list")).field("subfield1"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr1,
        )?;
        // Only the inner unnest is replaced by a placeholder column; the
        // outer unnest is still present in the output.
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield1")
            )]
        );

        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        // unnest(unnest(struct_list)['subfield2'])
        let select_expr2 = unnest(unnest(col("struct_list")).field("subfield2"));
        let transformed_exprs = rewrite_recursive_unnest_bottom_up(
            &input,
            &mut unnest_placeholder_columns,
            &mut inner_projection_exprs,
            &select_expr2,
        )?;
        // Same shape as above, reusing the same placeholder column.
        assert_eq!(
            transformed_exprs,
            vec![unnest(
                col("__unnest_placeholder(struct_list,depth=1)")
                    .alias("UNNEST(struct_list)")
                    .field("subfield2")
            )]
        );

        // The inner unnest is identical to the first one, so neither the
        // placeholder map nor the inner projection gains a new entry.
        column_unnests_eq(
            vec![
                "__unnest_placeholder(struct_list)=>[__unnest_placeholder(struct_list,depth=1)|depth=1]",
            ],
            &unnest_placeholder_columns,
        );

        assert_eq!(
            inner_projection_exprs,
            vec![col("struct_list").alias("__unnest_placeholder(struct_list)")]
        );

        Ok(())
    }
1048
1049 #[test]
1050 fn test_resolve_positions_to_exprs() -> Result<()> {
1051 let select_exprs = vec![col("c1"), col("c2"), count(lit(1))];
1052
1053 let resolved = resolve_positions_to_exprs(lit(1i64), &select_exprs)?;
1055 assert_eq!(resolved, col("c1"));
1056
1057 let resolved = resolve_positions_to_exprs(lit(-1i64), &select_exprs);
1059 assert!(resolved.is_err_and(|e| e.message().contains(
1060 "Cannot find column with position -1 in SELECT clause. Valid columns: 1 to 3"
1061 )));
1062
1063 let resolved = resolve_positions_to_exprs(lit(5i64), &select_exprs);
1064 assert!(resolved.is_err_and(|e| e.message().contains(
1065 "Cannot find column with position 5 in SELECT clause. Valid columns: 1 to 3"
1066 )));
1067
1068 let resolved = resolve_positions_to_exprs(lit("text"), &select_exprs)?;
1070 assert_eq!(resolved, lit("text"));
1071
1072 let resolved = resolve_positions_to_exprs(col("fake"), &select_exprs)?;
1073 assert_eq!(resolved, col("fake"));
1074
1075 Ok(())
1076 }
1077}