datafusion_functions_nested/
planner.rs1use arrow::datatypes::DataType;
21use datafusion_common::{DFSchema, Result, plan_err, utils::list_ndims};
22use datafusion_expr::AggregateUDF;
23use datafusion_expr::expr::ScalarFunction;
24use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
25#[cfg(feature = "sql")]
26use datafusion_expr::sqlparser::ast::BinaryOperator;
27use datafusion_expr::{
28 Expr, ExprSchemable, GetFieldAccess,
29 planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
30};
31#[cfg(not(feature = "sql"))]
32use datafusion_expr_common::operator::Operator as BinaryOperator;
33use datafusion_functions::core::get_field as get_field_inner;
34use datafusion_functions::expr_fn::get_field;
35use datafusion_functions_aggregate::nth_value::nth_value_udaf;
36use std::sync::Arc;
37
38use crate::map::map_udf;
39use crate::{
40 array_has::array_has_all,
41 expr_fn::{array_append, array_concat, array_prepend},
42 extract::{array_element, array_slice},
43 make_array::make_array,
44};
45
46#[derive(Debug)]
47pub struct NestedFunctionPlanner;
48
49impl ExprPlanner for NestedFunctionPlanner {
50 fn plan_binary_op(
51 &self,
52 expr: RawBinaryExpr,
53 schema: &DFSchema,
54 ) -> Result<PlannerResult<RawBinaryExpr>> {
55 let RawBinaryExpr { op, left, right } = expr;
56
57 if op == BinaryOperator::StringConcat {
58 let left_type = left.get_type(schema)?;
59 let right_type = right.get_type(schema)?;
60 let left_list_ndims = list_ndims(&left_type);
61 let right_list_ndims = list_ndims(&right_type);
62
63 if left_list_ndims + right_list_ndims == 0 {
72 } else if left_list_ndims == right_list_ndims {
75 return Ok(PlannerResult::Planned(array_concat(vec![left, right])));
76 } else if left_list_ndims > right_list_ndims {
77 return Ok(PlannerResult::Planned(array_append(left, right)));
78 } else if left_list_ndims < right_list_ndims {
79 return Ok(PlannerResult::Planned(array_prepend(left, right)));
80 }
81 } else if matches!(op, BinaryOperator::AtArrow | BinaryOperator::ArrowAt) {
82 let left_type = left.get_type(schema)?;
83 let right_type = right.get_type(schema)?;
84 let left_list_ndims = list_ndims(&left_type);
85 let right_list_ndims = list_ndims(&right_type);
86 if left_list_ndims > 0 && right_list_ndims > 0 {
88 if op == BinaryOperator::AtArrow {
89 return Ok(PlannerResult::Planned(array_has_all(left, right)));
91 } else {
92 return Ok(PlannerResult::Planned(array_has_all(right, left)));
94 }
95 }
96 }
97
98 Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }))
99 }
100
101 fn plan_array_literal(
102 &self,
103 exprs: Vec<Expr>,
104 _schema: &DFSchema,
105 ) -> Result<PlannerResult<Vec<Expr>>> {
106 Ok(PlannerResult::Planned(make_array(exprs)))
107 }
108
109 fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
110 if !args.len().is_multiple_of(2) {
111 return plan_err!("make_map requires an even number of arguments");
112 }
113
114 let (keys, values): (Vec<_>, Vec<_>) =
115 args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0);
116 let keys = make_array(keys.into_iter().map(|(_, e)| e).collect());
117 let values = make_array(values.into_iter().map(|(_, e)| e).collect());
118
119 Ok(PlannerResult::Planned(Expr::ScalarFunction(
120 ScalarFunction::new_udf(map_udf(), vec![keys, values]),
121 )))
122 }
123}
124
125#[derive(Debug)]
126pub struct FieldAccessPlanner;
127impl ExprPlanner for FieldAccessPlanner {
128 fn plan_field_access(
129 &self,
130 expr: RawFieldAccessExpr,
131 schema: &DFSchema,
132 ) -> Result<PlannerResult<RawFieldAccessExpr>> {
133 let RawFieldAccessExpr { expr, field_access } = expr;
134
135 match field_access {
136 GetFieldAccess::NamedStructField { name } => {
141 Ok(PlannerResult::Planned(get_field(expr, name)))
142 }
143 GetFieldAccess::ListIndex { key: index } => {
145 match expr {
146 Expr::AggregateFunction(AggregateFunction {
148 func,
149 params:
150 AggregateFunctionParams {
151 args,
152 distinct,
153 filter,
154 order_by,
155 null_treatment,
156 },
157 }) if is_array_agg(&func) => Ok(PlannerResult::Planned(
158 Expr::AggregateFunction(AggregateFunction::new_udf(
159 nth_value_udaf(),
160 args.into_iter().chain(std::iter::once(*index)).collect(),
161 distinct,
162 filter,
163 order_by,
164 null_treatment,
165 )),
166 )),
167 _ if matches!(expr.get_type(schema)?, DataType::Map(_, _)) => {
169 Ok(PlannerResult::Planned(Expr::ScalarFunction(
170 ScalarFunction::new_udf(
171 get_field_inner(),
172 vec![expr, *index],
173 ),
174 )))
175 }
176 _ => Ok(PlannerResult::Planned(array_element(expr, *index))),
177 }
178 }
179 GetFieldAccess::ListRange {
181 start,
182 stop,
183 stride,
184 } => Ok(PlannerResult::Planned(array_slice(
185 expr,
186 *start,
187 *stop,
188 Some(*stride),
189 ))),
190 }
191 }
192}
193
194fn is_array_agg(func: &Arc<AggregateUDF>) -> bool {
195 func.name() == "array_agg"
196}