// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`]
1920use arrow::datatypes::DataType;
21use datafusion_common::ExprSchema;
22use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result};
23use datafusion_expr::expr::ScalarFunction;
24use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
25use datafusion_expr::AggregateUDF;
26use datafusion_expr::{
27 planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
28 sqlparser, Expr, ExprSchemable, GetFieldAccess,
29};
30use datafusion_functions::core::get_field as get_field_inner;
31use datafusion_functions::expr_fn::get_field;
32use datafusion_functions_aggregate::nth_value::nth_value_udaf;
33use std::sync::Arc;
3435use crate::map::map_udf;
36use crate::{
37 array_has::{array_has_all, array_has_udf},
38 expr_fn::{array_append, array_concat, array_prepend},
39 extract::{array_element, array_slice},
40 make_array::make_array,
41};
4243#[derive(Debug)]
44pub struct NestedFunctionPlanner;
4546impl ExprPlanner for NestedFunctionPlanner {
47fn plan_binary_op(
48&self,
49 expr: RawBinaryExpr,
50 schema: &DFSchema,
51 ) -> Result<PlannerResult<RawBinaryExpr>> {
52let RawBinaryExpr { op, left, right } = expr;
5354if op == sqlparser::ast::BinaryOperator::StringConcat {
55let left_type = left.get_type(schema)?;
56let right_type = right.get_type(schema)?;
57let left_list_ndims = list_ndims(&left_type);
58let right_list_ndims = list_ndims(&right_type);
5960// Rewrite string concat operator to function based on types
61 // if we get list || list then we rewrite it to array_concat()
62 // if we get list || non-list then we rewrite it to array_append()
63 // if we get non-list || list then we rewrite it to array_prepend()
64 // if we get string || string then we rewrite it to concat()
6566 // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient.
67 // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite.
68if left_list_ndims + right_list_ndims == 0 {
69// TODO: concat function ignore null, but string concat takes null into consideration
70 // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator`
71} else if left_list_ndims == right_list_ndims {
72return Ok(PlannerResult::Planned(array_concat(vec![left, right])));
73 } else if left_list_ndims > right_list_ndims {
74return Ok(PlannerResult::Planned(array_append(left, right)));
75 } else if left_list_ndims < right_list_ndims {
76return Ok(PlannerResult::Planned(array_prepend(left, right)));
77 }
78 } else if matches!(
79 op,
80 sqlparser::ast::BinaryOperator::AtArrow
81 | sqlparser::ast::BinaryOperator::ArrowAt
82 ) {
83let left_type = left.get_type(schema)?;
84let right_type = right.get_type(schema)?;
85let left_list_ndims = list_ndims(&left_type);
86let right_list_ndims = list_ndims(&right_type);
87// if both are list
88if left_list_ndims > 0 && right_list_ndims > 0 {
89if op == sqlparser::ast::BinaryOperator::AtArrow {
90// array1 @> array2 -> array_has_all(array1, array2)
91return Ok(PlannerResult::Planned(array_has_all(left, right)));
92 } else {
93// array1 <@ array2 -> array_has_all(array2, array1)
94return Ok(PlannerResult::Planned(array_has_all(right, left)));
95 }
96 }
97 }
9899Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }))
100 }
101102fn plan_array_literal(
103&self,
104 exprs: Vec<Expr>,
105 _schema: &DFSchema,
106 ) -> Result<PlannerResult<Vec<Expr>>> {
107Ok(PlannerResult::Planned(make_array(exprs)))
108 }
109110fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
111if args.len() % 2 != 0 {
112return plan_err!("make_map requires an even number of arguments");
113 }
114115let (keys, values): (Vec<_>, Vec<_>) =
116 args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0);
117let keys = make_array(keys.into_iter().map(|(_, e)| e).collect());
118let values = make_array(values.into_iter().map(|(_, e)| e).collect());
119120Ok(PlannerResult::Planned(Expr::ScalarFunction(
121 ScalarFunction::new_udf(map_udf(), vec![keys, values]),
122 )))
123 }
124125fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
126if expr.op == sqlparser::ast::BinaryOperator::Eq {
127Ok(PlannerResult::Planned(Expr::ScalarFunction(
128 ScalarFunction::new_udf(
129 array_has_udf(),
130// left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)`
131vec![expr.right, expr.left],
132 ),
133 )))
134 } else {
135plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op)
136 }
137 }
138}
#[derive(Debug)]
pub struct FieldAccessPlanner;

impl ExprPlanner for FieldAccessPlanner {
    /// Plans field-access syntax (`expr["field"]`, `expr[idx]`,
    /// `expr[start:stop:stride]`) into the matching nested-type function
    /// call, with special cases for `array_agg` and map-typed columns.
    fn plan_field_access(
        &self,
        expr: RawFieldAccessExpr,
        schema: &DFSchema,
    ) -> Result<PlannerResult<RawFieldAccessExpr>> {
        let RawFieldAccessExpr { expr, field_access } = expr;

        match field_access {
            // expr["field"] => get_field(expr, "field")
            GetFieldAccess::NamedStructField { name } => {
                Ok(PlannerResult::Planned(get_field(expr, name)))
            }
            // expr[idx] ==> array_element(expr, idx)
            GetFieldAccess::ListIndex { key: index } => {
                match expr {
                    // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index)
                    Expr::AggregateFunction(AggregateFunction {
                        func,
                        params:
                            AggregateFunctionParams {
                                args,
                                distinct,
                                filter,
                                order_by,
                                null_treatment,
                            },
                    }) if is_array_agg(&func) => Ok(PlannerResult::Planned(
                        Expr::AggregateFunction(AggregateFunction::new_udf(
                            nth_value_udaf(),
                            // nth_value takes the index as an extra trailing argument
                            args.into_iter().chain(std::iter::once(*index)).collect(),
                            distinct,
                            filter,
                            order_by,
                            null_treatment,
                        )),
                    )),
                    // special case for map access with a key: map_col[key] => get_field(map_col, key)
                    Expr::Column(ref c)
                        if matches!(schema.data_type(c)?, DataType::Map(_, _)) =>
                    {
                        Ok(PlannerResult::Planned(Expr::ScalarFunction(
                            ScalarFunction::new_udf(
                                get_field_inner(),
                                vec![expr, *index],
                            ),
                        )))
                    }
                    // default list indexing
                    _ => Ok(PlannerResult::Planned(array_element(expr, *index))),
                }
            }
            // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
            GetFieldAccess::ListRange {
                start,
                stop,
                stride,
            } => Ok(PlannerResult::Planned(array_slice(
                expr,
                *start,
                *stop,
                Some(*stride),
            ))),
        }
    }
}
207208fn is_array_agg(func: &Arc<AggregateUDF>) -> bool {
209 func.name() == "array_agg"
210}