datafusion_functions_nested/
planner.rs1use arrow::datatypes::DataType;
21use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result};
22use datafusion_expr::expr::ScalarFunction;
23use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
24#[cfg(feature = "sql")]
25use datafusion_expr::sqlparser::ast::BinaryOperator;
26use datafusion_expr::AggregateUDF;
27use datafusion_expr::{
28 planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
29 Expr, ExprSchemable, GetFieldAccess,
30};
31#[cfg(not(feature = "sql"))]
32use datafusion_expr_common::operator::Operator as BinaryOperator;
33use datafusion_functions::core::get_field as get_field_inner;
34use datafusion_functions::expr_fn::get_field;
35use datafusion_functions_aggregate::nth_value::nth_value_udaf;
36use std::sync::Arc;
37
38use crate::map::map_udf;
39use crate::{
40 array_has::{array_has_all, array_has_udf},
41 expr_fn::{array_append, array_concat, array_prepend},
42 extract::{array_element, array_slice},
43 make_array::make_array,
44};
45
46#[derive(Debug)]
47pub struct NestedFunctionPlanner;
48
49impl ExprPlanner for NestedFunctionPlanner {
50 fn plan_binary_op(
51 &self,
52 expr: RawBinaryExpr,
53 schema: &DFSchema,
54 ) -> Result<PlannerResult<RawBinaryExpr>> {
55 let RawBinaryExpr { op, left, right } = expr;
56
57 if op == BinaryOperator::StringConcat {
58 let left_type = left.get_type(schema)?;
59 let right_type = right.get_type(schema)?;
60 let left_list_ndims = list_ndims(&left_type);
61 let right_list_ndims = list_ndims(&right_type);
62
63 if left_list_ndims + right_list_ndims == 0 {
72 } else if left_list_ndims == right_list_ndims {
75 return Ok(PlannerResult::Planned(array_concat(vec![left, right])));
76 } else if left_list_ndims > right_list_ndims {
77 return Ok(PlannerResult::Planned(array_append(left, right)));
78 } else if left_list_ndims < right_list_ndims {
79 return Ok(PlannerResult::Planned(array_prepend(left, right)));
80 }
81 } else if matches!(op, BinaryOperator::AtArrow | BinaryOperator::ArrowAt) {
82 let left_type = left.get_type(schema)?;
83 let right_type = right.get_type(schema)?;
84 let left_list_ndims = list_ndims(&left_type);
85 let right_list_ndims = list_ndims(&right_type);
86 if left_list_ndims > 0 && right_list_ndims > 0 {
88 if op == BinaryOperator::AtArrow {
89 return Ok(PlannerResult::Planned(array_has_all(left, right)));
91 } else {
92 return Ok(PlannerResult::Planned(array_has_all(right, left)));
94 }
95 }
96 }
97
98 Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }))
99 }
100
101 fn plan_array_literal(
102 &self,
103 exprs: Vec<Expr>,
104 _schema: &DFSchema,
105 ) -> Result<PlannerResult<Vec<Expr>>> {
106 Ok(PlannerResult::Planned(make_array(exprs)))
107 }
108
109 fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
110 if !args.len().is_multiple_of(2) {
111 return plan_err!("make_map requires an even number of arguments");
112 }
113
114 let (keys, values): (Vec<_>, Vec<_>) =
115 args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0);
116 let keys = make_array(keys.into_iter().map(|(_, e)| e).collect());
117 let values = make_array(values.into_iter().map(|(_, e)| e).collect());
118
119 Ok(PlannerResult::Planned(Expr::ScalarFunction(
120 ScalarFunction::new_udf(map_udf(), vec![keys, values]),
121 )))
122 }
123
124 fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
125 if expr.op == BinaryOperator::Eq {
126 Ok(PlannerResult::Planned(Expr::ScalarFunction(
127 ScalarFunction::new_udf(
128 array_has_udf(),
129 vec![expr.right, expr.left],
131 ),
132 )))
133 } else {
134 plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op)
135 }
136 }
137}
138
139#[derive(Debug)]
140pub struct FieldAccessPlanner;
141impl ExprPlanner for FieldAccessPlanner {
142 fn plan_field_access(
143 &self,
144 expr: RawFieldAccessExpr,
145 schema: &DFSchema,
146 ) -> Result<PlannerResult<RawFieldAccessExpr>> {
147 let RawFieldAccessExpr { expr, field_access } = expr;
148
149 match field_access {
150 GetFieldAccess::NamedStructField { name } => {
152 Ok(PlannerResult::Planned(get_field(expr, name)))
153 }
154 GetFieldAccess::ListIndex { key: index } => {
156 match expr {
157 Expr::AggregateFunction(AggregateFunction {
159 func,
160 params:
161 AggregateFunctionParams {
162 args,
163 distinct,
164 filter,
165 order_by,
166 null_treatment,
167 },
168 }) if is_array_agg(&func) => Ok(PlannerResult::Planned(
169 Expr::AggregateFunction(AggregateFunction::new_udf(
170 nth_value_udaf(),
171 args.into_iter().chain(std::iter::once(*index)).collect(),
172 distinct,
173 filter,
174 order_by,
175 null_treatment,
176 )),
177 )),
178 _ if matches!(expr.get_type(schema)?, DataType::Map(_, _)) => {
180 Ok(PlannerResult::Planned(Expr::ScalarFunction(
181 ScalarFunction::new_udf(
182 get_field_inner(),
183 vec![expr, *index],
184 ),
185 )))
186 }
187 _ => Ok(PlannerResult::Planned(array_element(expr, *index))),
188 }
189 }
190 GetFieldAccess::ListRange {
192 start,
193 stop,
194 stride,
195 } => Ok(PlannerResult::Planned(array_slice(
196 expr,
197 *start,
198 *stop,
199 Some(*stride),
200 ))),
201 }
202 }
203}
204
205fn is_array_agg(func: &Arc<AggregateUDF>) -> bool {
206 func.name() == "array_agg"
207}