datafusion_functions_nested/
planner.rs1use arrow::datatypes::DataType;
21use datafusion_common::ExprSchema;
22use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result};
23use datafusion_expr::expr::ScalarFunction;
24use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
25use datafusion_expr::AggregateUDF;
26use datafusion_expr::{
27 planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
28 sqlparser, Expr, ExprSchemable, GetFieldAccess,
29};
30use datafusion_functions::core::get_field as get_field_inner;
31use datafusion_functions::expr_fn::get_field;
32use datafusion_functions_aggregate::nth_value::nth_value_udaf;
33use std::sync::Arc;
34
35use crate::map::map_udf;
36use crate::{
37 array_has::{array_has_all, array_has_udf},
38 expr_fn::{array_append, array_concat, array_prepend},
39 extract::{array_element, array_slice},
40 make_array::make_array,
41};
42
43#[derive(Debug)]
44pub struct NestedFunctionPlanner;
45
46impl ExprPlanner for NestedFunctionPlanner {
47 fn plan_binary_op(
48 &self,
49 expr: RawBinaryExpr,
50 schema: &DFSchema,
51 ) -> Result<PlannerResult<RawBinaryExpr>> {
52 let RawBinaryExpr { op, left, right } = expr;
53
54 if op == sqlparser::ast::BinaryOperator::StringConcat {
55 let left_type = left.get_type(schema)?;
56 let right_type = right.get_type(schema)?;
57 let left_list_ndims = list_ndims(&left_type);
58 let right_list_ndims = list_ndims(&right_type);
59
60 if left_list_ndims + right_list_ndims == 0 {
69 } else if left_list_ndims == right_list_ndims {
72 return Ok(PlannerResult::Planned(array_concat(vec![left, right])));
73 } else if left_list_ndims > right_list_ndims {
74 return Ok(PlannerResult::Planned(array_append(left, right)));
75 } else if left_list_ndims < right_list_ndims {
76 return Ok(PlannerResult::Planned(array_prepend(left, right)));
77 }
78 } else if matches!(
79 op,
80 sqlparser::ast::BinaryOperator::AtArrow
81 | sqlparser::ast::BinaryOperator::ArrowAt
82 ) {
83 let left_type = left.get_type(schema)?;
84 let right_type = right.get_type(schema)?;
85 let left_list_ndims = list_ndims(&left_type);
86 let right_list_ndims = list_ndims(&right_type);
87 if left_list_ndims > 0 && right_list_ndims > 0 {
89 if op == sqlparser::ast::BinaryOperator::AtArrow {
90 return Ok(PlannerResult::Planned(array_has_all(left, right)));
92 } else {
93 return Ok(PlannerResult::Planned(array_has_all(right, left)));
95 }
96 }
97 }
98
99 Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }))
100 }
101
102 fn plan_array_literal(
103 &self,
104 exprs: Vec<Expr>,
105 _schema: &DFSchema,
106 ) -> Result<PlannerResult<Vec<Expr>>> {
107 Ok(PlannerResult::Planned(make_array(exprs)))
108 }
109
110 fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
111 if args.len() % 2 != 0 {
112 return plan_err!("make_map requires an even number of arguments");
113 }
114
115 let (keys, values): (Vec<_>, Vec<_>) =
116 args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0);
117 let keys = make_array(keys.into_iter().map(|(_, e)| e).collect());
118 let values = make_array(values.into_iter().map(|(_, e)| e).collect());
119
120 Ok(PlannerResult::Planned(Expr::ScalarFunction(
121 ScalarFunction::new_udf(map_udf(), vec![keys, values]),
122 )))
123 }
124
125 fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
126 if expr.op == sqlparser::ast::BinaryOperator::Eq {
127 Ok(PlannerResult::Planned(Expr::ScalarFunction(
128 ScalarFunction::new_udf(
129 array_has_udf(),
130 vec![expr.right, expr.left],
132 ),
133 )))
134 } else {
135 plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op)
136 }
137 }
138}
139
140#[derive(Debug)]
141pub struct FieldAccessPlanner;
142impl ExprPlanner for FieldAccessPlanner {
143 fn plan_field_access(
144 &self,
145 expr: RawFieldAccessExpr,
146 schema: &DFSchema,
147 ) -> Result<PlannerResult<RawFieldAccessExpr>> {
148 let RawFieldAccessExpr { expr, field_access } = expr;
149
150 match field_access {
151 GetFieldAccess::NamedStructField { name } => {
153 Ok(PlannerResult::Planned(get_field(expr, name)))
154 }
155 GetFieldAccess::ListIndex { key: index } => {
157 match expr {
158 Expr::AggregateFunction(AggregateFunction {
160 func,
161 params:
162 AggregateFunctionParams {
163 args,
164 distinct,
165 filter,
166 order_by,
167 null_treatment,
168 },
169 }) if is_array_agg(&func) => Ok(PlannerResult::Planned(
170 Expr::AggregateFunction(AggregateFunction::new_udf(
171 nth_value_udaf(),
172 args.into_iter().chain(std::iter::once(*index)).collect(),
173 distinct,
174 filter,
175 order_by,
176 null_treatment,
177 )),
178 )),
179 Expr::Column(ref c)
181 if matches!(schema.data_type(c)?, DataType::Map(_, _)) =>
182 {
183 Ok(PlannerResult::Planned(Expr::ScalarFunction(
184 ScalarFunction::new_udf(
185 get_field_inner(),
186 vec![expr, *index],
187 ),
188 )))
189 }
190 _ => Ok(PlannerResult::Planned(array_element(expr, *index))),
191 }
192 }
193 GetFieldAccess::ListRange {
195 start,
196 stop,
197 stride,
198 } => Ok(PlannerResult::Planned(array_slice(
199 expr,
200 *start,
201 *stop,
202 Some(*stride),
203 ))),
204 }
205 }
206}
207
208fn is_array_agg(func: &Arc<AggregateUDF>) -> bool {
209 func.name() == "array_agg"
210}