datafusion_functions_nested/
planner.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`]
19
20use arrow::datatypes::DataType;
21use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result};
22use datafusion_expr::expr::ScalarFunction;
23use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
24#[cfg(feature = "sql")]
25use datafusion_expr::sqlparser::ast::BinaryOperator;
26use datafusion_expr::AggregateUDF;
27use datafusion_expr::{
28    planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
29    Expr, ExprSchemable, GetFieldAccess,
30};
31#[cfg(not(feature = "sql"))]
32use datafusion_expr_common::operator::Operator as BinaryOperator;
33use datafusion_functions::core::get_field as get_field_inner;
34use datafusion_functions::expr_fn::get_field;
35use datafusion_functions_aggregate::nth_value::nth_value_udaf;
36use std::sync::Arc;
37
38use crate::map::map_udf;
39use crate::{
40    array_has::{array_has_all, array_has_udf},
41    expr_fn::{array_append, array_concat, array_prepend},
42    extract::{array_element, array_slice},
43    make_array::make_array,
44};
45
/// [`ExprPlanner`] that rewrites SQL constructs over nested types into calls to
/// the nested-type functions of this crate:
///
/// * `||` with at least one list operand, to `array_concat` / `array_append` /
///   `array_prepend`;
/// * `@>` / `<@` between two lists, to `array_has_all`;
/// * array literals, to `make_array`;
/// * key/value map construction, to the `map` UDF;
/// * `needle = ANY(haystack)`, to `array_has(haystack, needle)`.
///
/// The planner is stateless, so it is cheap to construct (`Default`) and to
/// copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct NestedFunctionPlanner;
48
49impl ExprPlanner for NestedFunctionPlanner {
50    fn plan_binary_op(
51        &self,
52        expr: RawBinaryExpr,
53        schema: &DFSchema,
54    ) -> Result<PlannerResult<RawBinaryExpr>> {
55        let RawBinaryExpr { op, left, right } = expr;
56
57        if op == BinaryOperator::StringConcat {
58            let left_type = left.get_type(schema)?;
59            let right_type = right.get_type(schema)?;
60            let left_list_ndims = list_ndims(&left_type);
61            let right_list_ndims = list_ndims(&right_type);
62
63            // Rewrite string concat operator to function based on types
64            // if we get list || list then we rewrite it to array_concat()
65            // if we get list || non-list then we rewrite it to array_append()
66            // if we get non-list || list then we rewrite it to array_prepend()
67            // if we get string || string then we rewrite it to concat()
68
69            // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient.
70            // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite.
71            if left_list_ndims + right_list_ndims == 0 {
72                // TODO: concat function ignore null, but string concat takes null into consideration
73                // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator`
74            } else if left_list_ndims == right_list_ndims {
75                return Ok(PlannerResult::Planned(array_concat(vec![left, right])));
76            } else if left_list_ndims > right_list_ndims {
77                return Ok(PlannerResult::Planned(array_append(left, right)));
78            } else if left_list_ndims < right_list_ndims {
79                return Ok(PlannerResult::Planned(array_prepend(left, right)));
80            }
81        } else if matches!(op, BinaryOperator::AtArrow | BinaryOperator::ArrowAt) {
82            let left_type = left.get_type(schema)?;
83            let right_type = right.get_type(schema)?;
84            let left_list_ndims = list_ndims(&left_type);
85            let right_list_ndims = list_ndims(&right_type);
86            // if both are list
87            if left_list_ndims > 0 && right_list_ndims > 0 {
88                if op == BinaryOperator::AtArrow {
89                    // array1 @> array2 -> array_has_all(array1, array2)
90                    return Ok(PlannerResult::Planned(array_has_all(left, right)));
91                } else {
92                    // array1 <@ array2 -> array_has_all(array2, array1)
93                    return Ok(PlannerResult::Planned(array_has_all(right, left)));
94                }
95            }
96        }
97
98        Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }))
99    }
100
101    fn plan_array_literal(
102        &self,
103        exprs: Vec<Expr>,
104        _schema: &DFSchema,
105    ) -> Result<PlannerResult<Vec<Expr>>> {
106        Ok(PlannerResult::Planned(make_array(exprs)))
107    }
108
109    fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
110        if !args.len().is_multiple_of(2) {
111            return plan_err!("make_map requires an even number of arguments");
112        }
113
114        let (keys, values): (Vec<_>, Vec<_>) =
115            args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0);
116        let keys = make_array(keys.into_iter().map(|(_, e)| e).collect());
117        let values = make_array(values.into_iter().map(|(_, e)| e).collect());
118
119        Ok(PlannerResult::Planned(Expr::ScalarFunction(
120            ScalarFunction::new_udf(map_udf(), vec![keys, values]),
121        )))
122    }
123
124    fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
125        if expr.op == BinaryOperator::Eq {
126            Ok(PlannerResult::Planned(Expr::ScalarFunction(
127                ScalarFunction::new_udf(
128                    array_has_udf(),
129                    // left and right are reversed here so `needle=any(haystack)` -> `array_has(haystack, needle)`
130                    vec![expr.right, expr.left],
131                ),
132            )))
133        } else {
134            plan_err!("Unsupported AnyOp: '{}', only '=' is supported", expr.op)
135        }
136    }
137}
138
/// [`ExprPlanner`] that rewrites field-access syntax into function calls:
/// `expr["field"]` becomes a struct field lookup, `expr[idx]` becomes an element
/// lookup, and `expr[start, stop, stride]` becomes a slice.
///
/// The planner is stateless, so it is cheap to construct (`Default`) and to
/// copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct FieldAccessPlanner;
141impl ExprPlanner for FieldAccessPlanner {
142    fn plan_field_access(
143        &self,
144        expr: RawFieldAccessExpr,
145        schema: &DFSchema,
146    ) -> Result<PlannerResult<RawFieldAccessExpr>> {
147        let RawFieldAccessExpr { expr, field_access } = expr;
148
149        match field_access {
150            // expr["field"] => get_field(expr, "field")
151            GetFieldAccess::NamedStructField { name } => {
152                Ok(PlannerResult::Planned(get_field(expr, name)))
153            }
154            // expr[idx] ==> array_element(expr, idx)
155            GetFieldAccess::ListIndex { key: index } => {
156                match expr {
157                    // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index)
158                    Expr::AggregateFunction(AggregateFunction {
159                        func,
160                        params:
161                            AggregateFunctionParams {
162                                args,
163                                distinct,
164                                filter,
165                                order_by,
166                                null_treatment,
167                            },
168                    }) if is_array_agg(&func) => Ok(PlannerResult::Planned(
169                        Expr::AggregateFunction(AggregateFunction::new_udf(
170                            nth_value_udaf(),
171                            args.into_iter().chain(std::iter::once(*index)).collect(),
172                            distinct,
173                            filter,
174                            order_by,
175                            null_treatment,
176                        )),
177                    )),
178                    // special case for map access with
179                    _ if matches!(expr.get_type(schema)?, DataType::Map(_, _)) => {
180                        Ok(PlannerResult::Planned(Expr::ScalarFunction(
181                            ScalarFunction::new_udf(
182                                get_field_inner(),
183                                vec![expr, *index],
184                            ),
185                        )))
186                    }
187                    _ => Ok(PlannerResult::Planned(array_element(expr, *index))),
188                }
189            }
190            // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
191            GetFieldAccess::ListRange {
192                start,
193                stop,
194                stride,
195            } => Ok(PlannerResult::Planned(array_slice(
196                expr,
197                *start,
198                *stop,
199                Some(*stride),
200            ))),
201        }
202    }
203}
204
205fn is_array_agg(func: &Arc<AggregateUDF>) -> bool {
206    func.name() == "array_agg"
207}