Skip to main content

datafusion_functions_nested/
planner.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`]
19
20use arrow::datatypes::DataType;
21use datafusion_common::{DFSchema, Result, plan_err, utils::list_ndims};
22use datafusion_expr::AggregateUDF;
23use datafusion_expr::expr::ScalarFunction;
24use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
25#[cfg(feature = "sql")]
26use datafusion_expr::sqlparser::ast::BinaryOperator;
27use datafusion_expr::{
28    Expr, ExprSchemable, GetFieldAccess,
29    planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr},
30};
31#[cfg(not(feature = "sql"))]
32use datafusion_expr_common::operator::Operator as BinaryOperator;
33use datafusion_functions::core::get_field as get_field_inner;
34use datafusion_functions::expr_fn::get_field;
35use datafusion_functions_aggregate::nth_value::nth_value_udaf;
36use std::sync::Arc;
37
38use crate::map::map_udf;
39use crate::{
40    array_has::array_has_all,
41    expr_fn::{array_append, array_concat, array_prepend},
42    extract::{array_element, array_slice},
43    make_array::make_array,
44};
45
46#[derive(Debug)]
47pub struct NestedFunctionPlanner;
48
49impl ExprPlanner for NestedFunctionPlanner {
50    fn plan_binary_op(
51        &self,
52        expr: RawBinaryExpr,
53        schema: &DFSchema,
54    ) -> Result<PlannerResult<RawBinaryExpr>> {
55        let RawBinaryExpr { op, left, right } = expr;
56
57        if op == BinaryOperator::StringConcat {
58            let left_type = left.get_type(schema)?;
59            let right_type = right.get_type(schema)?;
60            let left_list_ndims = list_ndims(&left_type);
61            let right_list_ndims = list_ndims(&right_type);
62
63            // Rewrite string concat operator to function based on types
64            // if we get list || list then we rewrite it to array_concat()
65            // if we get list || non-list then we rewrite it to array_append()
66            // if we get non-list || list then we rewrite it to array_prepend()
67            // if we get string || string then we rewrite it to concat()
68
69            // We determine the target function to rewrite based on the list n-dimension, the check is not exact but sufficient.
70            // The exact validity check is handled in the actual function, so even if there is 3d list appended with 1d list, it is also fine to rewrite.
71            if left_list_ndims + right_list_ndims == 0 {
72                // TODO: concat function ignore null, but string concat takes null into consideration
73                // we can rewrite it to concat if we can configure the behaviour of concat function to the one like `string concat operator`
74            } else if left_list_ndims == right_list_ndims {
75                return Ok(PlannerResult::Planned(array_concat(vec![left, right])));
76            } else if left_list_ndims > right_list_ndims {
77                return Ok(PlannerResult::Planned(array_append(left, right)));
78            } else if left_list_ndims < right_list_ndims {
79                return Ok(PlannerResult::Planned(array_prepend(left, right)));
80            }
81        } else if matches!(op, BinaryOperator::AtArrow | BinaryOperator::ArrowAt) {
82            let left_type = left.get_type(schema)?;
83            let right_type = right.get_type(schema)?;
84            let left_list_ndims = list_ndims(&left_type);
85            let right_list_ndims = list_ndims(&right_type);
86            // if both are list
87            if left_list_ndims > 0 && right_list_ndims > 0 {
88                if op == BinaryOperator::AtArrow {
89                    // array1 @> array2 -> array_has_all(array1, array2)
90                    return Ok(PlannerResult::Planned(array_has_all(left, right)));
91                } else {
92                    // array1 <@ array2 -> array_has_all(array2, array1)
93                    return Ok(PlannerResult::Planned(array_has_all(right, left)));
94                }
95            }
96        }
97
98        Ok(PlannerResult::Original(RawBinaryExpr { op, left, right }))
99    }
100
101    fn plan_array_literal(
102        &self,
103        exprs: Vec<Expr>,
104        _schema: &DFSchema,
105    ) -> Result<PlannerResult<Vec<Expr>>> {
106        Ok(PlannerResult::Planned(make_array(exprs)))
107    }
108
109    fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
110        if !args.len().is_multiple_of(2) {
111            return plan_err!("make_map requires an even number of arguments");
112        }
113
114        let (keys, values): (Vec<_>, Vec<_>) =
115            args.into_iter().enumerate().partition(|(i, _)| i % 2 == 0);
116        let keys = make_array(keys.into_iter().map(|(_, e)| e).collect());
117        let values = make_array(values.into_iter().map(|(_, e)| e).collect());
118
119        Ok(PlannerResult::Planned(Expr::ScalarFunction(
120            ScalarFunction::new_udf(map_udf(), vec![keys, values]),
121        )))
122    }
123}
124
125#[derive(Debug)]
126pub struct FieldAccessPlanner;
127impl ExprPlanner for FieldAccessPlanner {
128    fn plan_field_access(
129        &self,
130        expr: RawFieldAccessExpr,
131        schema: &DFSchema,
132    ) -> Result<PlannerResult<RawFieldAccessExpr>> {
133        let RawFieldAccessExpr { expr, field_access } = expr;
134
135        match field_access {
136            // expr["field"] => get_field(expr, "field")
137            // Nested accesses like expr["a"]["b"] create nested get_field calls,
138            // which are then merged by the SimplifyExpressions optimizer pass via
139            // the GetFieldFunc::simplify() method.
140            GetFieldAccess::NamedStructField { name } => {
141                Ok(PlannerResult::Planned(get_field(expr, name)))
142            }
143            // expr[idx] ==> array_element(expr, idx)
144            GetFieldAccess::ListIndex { key: index } => {
145                match expr {
146                    // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index)
147                    Expr::AggregateFunction(AggregateFunction {
148                        func,
149                        params:
150                            AggregateFunctionParams {
151                                args,
152                                distinct,
153                                filter,
154                                order_by,
155                                null_treatment,
156                            },
157                    }) if is_array_agg(&func) => Ok(PlannerResult::Planned(
158                        Expr::AggregateFunction(AggregateFunction::new_udf(
159                            nth_value_udaf(),
160                            args.into_iter().chain(std::iter::once(*index)).collect(),
161                            distinct,
162                            filter,
163                            order_by,
164                            null_treatment,
165                        )),
166                    )),
167                    // special case for map access with
168                    _ if matches!(expr.get_type(schema)?, DataType::Map(_, _)) => {
169                        Ok(PlannerResult::Planned(Expr::ScalarFunction(
170                            ScalarFunction::new_udf(
171                                get_field_inner(),
172                                vec![expr, *index],
173                            ),
174                        )))
175                    }
176                    _ => Ok(PlannerResult::Planned(array_element(expr, *index))),
177                }
178            }
179            // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride)
180            GetFieldAccess::ListRange {
181                start,
182                stop,
183                stride,
184            } => Ok(PlannerResult::Planned(array_slice(
185                expr,
186                *start,
187                *stop,
188                Some(*stride),
189            ))),
190        }
191    }
192}
193
194fn is_array_agg(func: &Arc<AggregateUDF>) -> bool {
195    func.name() == "array_agg"
196}