// datafusion_functions_nested/planner.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`]

use arrow::datatypes::DataType;
use datafusion_common::ExprSchema;
use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::AggregateUDF;
use datafusion_expr::{
    planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
    sqlparser, Expr, ExprSchemable, GetFieldAccess,
};
use datafusion_functions::core::get_field as get_field_inner;
use datafusion_functions::expr_fn::get_field;
use datafusion_functions_aggregate::nth_value::nth_value_udaf;
use std::sync::Arc;

use crate::map::map_udf;
use crate::{
    array_has::{array_has_all, array_has_udf},
    expr_fn::{array_append, array_concat, array_prepend},
    extract::{array_element, array_slice},
    make_array::make_array,
};

/// [`ExprPlanner`] that rewrites SQL operators and constructs over nested
/// (array/map) types into calls to this crate's functions: `||` into
/// `array_concat`/`array_append`/`array_prepend`, `@>`/`<@` into
/// `array_has_all`, array literals into `make_array`, `make_map` arguments
/// into a `map` call, and `= ANY(...)` into `array_has`.
#[derive(Debug)]
pub struct NestedFunctionPlanner;
45
46impl ExprPlanner for NestedFunctionPlanner {
47    fn plan_binary_op(
48        &self,
49        expr: RawBinaryExpr,
50        schema: &DFSchema,
51    ) -> Result<PlannerResult<RawBinaryExpr>> {
52        let RawBinaryExpr { op, left, right } = expr;
53
54        if op == sqlparser::ast::BinaryOperator::StringConcat {
55            let left_type = left.get_type(schema)?;
56            let right_type = right.get_type(schema)?;
57            let left_list_ndims = list_ndims(&left_type);
58            let right_list_ndims = list_ndims(&right_type);
59
60            // Rewrite string concat operator to function based on types
61            // if we get list || list then we rewrite it to array_concat()
62            // if we get list || non-list then we rewrite it to array_append()
63            // if we get non-list || list then we rewrite it to array_prepend()
64            // if we get string || string then we rewrite it to concat()
65
66            // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient.
67            // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite.
68            if left_list_ndims + right_list_ndims == 0 {
69                // TODO: concat function ignore null, but string concat takes null into consideration
70                // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator`
71            } else if left_list_ndims == right_list_ndims {
72                return Ok(PlannerResult::Planned(array_concat(vec![left, right])));
73            } else if left_list_ndims > right_list_ndims {
74                return Ok(PlannerResult::Planned(array_append(left, right)));
75            } else if left_list_ndims < right_list_ndims {
76                return Ok(PlannerResult::Planned(array_prepend(left, right)));
77            }
78        } else if matches!(
79            op,
80            sqlparser::ast::BinaryOperator::AtArrow
81                | sqlparser::ast::BinaryOperator::ArrowAt
82        ) {
83            let left_type = left.get_type(schema)?;
84            let right_type = right.get_type(schema)?;
85            let left_list_ndims = list_ndims(&left_type);
86            let right_list_ndims = list_ndims(&right_type);
87            // if both are list
88            if left_list_ndims > 0 && right_list_ndims > 0 {
89                if op == sqlparser::ast::BinaryOperator::AtArrow {
90                    // array1 @> array2 -> array_has_all(array1, array2)
91                    return Ok(PlannerResult::Planned(array_has_all(left, right)));
92                } else {
93                    // array1 <@ array2 -> array_has_all(array2, array1)
94                    return Ok(PlannerResult::Planned(array_has_all(right, left)));
95                }
96            }
97        }
98
99        Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }))
100    }
101
102    fn plan_array_literal(
103        &self,
104        exprs: Vec<Expr>,
105        _schema: &DFSchema,
106    ) -> Result<PlannerResult<Vec<Expr>>> {
107        Ok(PlannerResult::Planned(make_array(exprs)))
108    }
109
110    fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
111        if args.len() % 2 != 0 {
112            return plan_err!("make_map requires an even number of arguments");
113        }
114
115        let (keys, values): (Vec<_>, Vec<_>) =
116            args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0);
117        let keys = make_array(keys.into_iter().map(|(_, e)| e).collect());
118        let values = make_array(values.into_iter().map(|(_, e)| e).collect());
119
120        Ok(PlannerResult::Planned(Expr::ScalarFunction(
121            ScalarFunction::new_udf(map_udf(), vec![keys, values]),
122        )))
123    }
124
125    fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
126        if expr.op == sqlparser::ast::BinaryOperator::Eq {
127            Ok(PlannerResult::Planned(Expr::ScalarFunction(
128                ScalarFunction::new_udf(
129                    array_has_udf(),
130                    // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)`
131                    vec![expr.right, expr.left],
132                ),
133            )))
134        } else {
135            plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op)
136        }
137    }
138}
139
/// [`ExprPlanner`] that lowers SQL field-access syntax: `expr["field"]`,
/// `expr[index]` (including the `array_agg(..)[n]` and map-column special
/// cases), and `expr[start:stop:stride]` slicing.
#[derive(Debug)]
pub struct FieldAccessPlanner;
142impl ExprPlanner for FieldAccessPlanner {
143    fn plan_field_access(
144        &self,
145        expr: RawFieldAccessExpr,
146        schema: &DFSchema,
147    ) -> Result<PlannerResult<RawFieldAccessExpr>> {
148        let RawFieldAccessExpr { expr, field_access } = expr;
149
150        match field_access {
151            // expr["field"] => get_field(expr, "field")
152            GetFieldAccess::NamedStructField { name } => {
153                Ok(PlannerResult::Planned(get_field(expr, name)))
154            }
155            // expr[idx] ==> array_element(expr, idx)
156            GetFieldAccess::ListIndex { key: index } => {
157                match expr {
158                    // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index)
159                    Expr::AggregateFunction(AggregateFunction {
160                        func,
161                        params:
162                            AggregateFunctionParams {
163                                args,
164                                distinct,
165                                filter,
166                                order_by,
167                                null_treatment,
168                            },
169                    }) if is_array_agg(&func) => Ok(PlannerResult::Planned(
170                        Expr::AggregateFunction(AggregateFunction::new_udf(
171                            nth_value_udaf(),
172                            args.into_iter().chain(std::iter::once(*index)).collect(),
173                            distinct,
174                            filter,
175                            order_by,
176                            null_treatment,
177                        )),
178                    )),
179                    // special case for map access with
180                    Expr::Column(ref c)
181                        if matches!(schema.data_type(c)?, DataType::Map(_, _)) =>
182                    {
183                        Ok(PlannerResult::Planned(Expr::ScalarFunction(
184                            ScalarFunction::new_udf(
185                                get_field_inner(),
186                                vec![expr, *index],
187                            ),
188                        )))
189                    }
190                    _ => Ok(PlannerResult::Planned(array_element(expr, *index))),
191                }
192            }
193            // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
194            GetFieldAccess::ListRange {
195                start,
196                stop,
197                stride,
198            } => Ok(PlannerResult::Planned(array_slice(
199                expr,
200                *start,
201                *stop,
202                Some(*stride),
203            ))),
204        }
205    }
206}
207
208fn is_array_agg(func: &Arc<AggregateUDF>) -> bool {
209    func.name() == "array_agg"
210}