Skip to main content

datafusion_functions_nested/
array_transform.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`datafusion_expr::HigherOrderUDF`] definitions for array_transform function.
19
20use arrow::{
21    array::{Array, ArrayRef, AsArray, LargeListArray, ListArray},
22    compute::take_arrays,
23    datatypes::{DataType, Field, FieldRef},
24};
25use datafusion_common::{
26    Result, exec_err, plan_err,
27    utils::{adjust_offsets_for_slice, list_values_row_number, take_function_args},
28};
29use datafusion_expr::{
30    ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs,
31    HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda,
32    Volatility,
33};
34use datafusion_macros::user_doc;
35use std::sync::Arc;
36
37use crate::lambda_utils::{
38    ListValuesResult, coerce_single_list_arg, extract_list_values,
39    single_list_lambda_parameters,
40};
41
// Macro-generated glue for `ArrayTransform`. The arguments are, in order: the
// implementing struct, the public function name, the argument names
// (`array`, `lambda`), a short description, and the name
// `array_transform_higher_order_function`.
// NOTE(review): the macro definition is not visible here; presumably it emits
// the `array_transform` expression builder and the
// `array_transform_higher_order_function()` accessor that the tests in this
// file import — confirm against the macro definition.
make_higher_order_function_expr_and_func!(
    ArrayTransform,
    array_transform,
    array lambda,
    "transforms the values of an array",
    array_transform_higher_order_function
);
49
50#[user_doc(
51    doc_section(label = "Array Functions"),
52    description = "transforms the values of an array",
53    syntax_example = "array_transform(array, x -> x*2)",
54    sql_example = r#"```sql
55> select array_transform([1, 2, 3, 4, 5], x -> x*2);
56+-------------------------------------------+
57| array_transform([1, 2, 3, 4, 5], x -> x*2)       |
58+-------------------------------------------+
59| [2, 4, 6, 8, 10]                          |
60+-------------------------------------------+
61```"#,
62    argument(
63        name = "array",
64        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
65    ),
66    argument(name = "lambda", description = "Lambda")
67)]
68#[derive(Debug, PartialEq, Eq, Hash)]
69pub struct ArrayTransform {
70    signature: HigherOrderSignature,
71    aliases: Vec<String>,
72}
73
74impl Default for ArrayTransform {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl ArrayTransform {
81    pub fn new() -> Self {
82        Self {
83            signature: HigherOrderSignature::exact(
84                vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
85                Volatility::Immutable,
86            ),
87            aliases: vec![String::from("list_transform")],
88        }
89    }
90}
91
92impl HigherOrderUDFImpl for ArrayTransform {
93    fn name(&self) -> &str {
94        "array_transform"
95    }
96
97    fn aliases(&self) -> &[String] {
98        &self.aliases
99    }
100
101    fn signature(&self) -> &HigherOrderSignature {
102        &self.signature
103    }
104
105    fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
106        coerce_single_list_arg(self.name(), arg_types)
107    }
108
109    fn lambda_parameters(
110        &self,
111        _step: usize,
112        fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
113    ) -> Result<LambdaParametersProgress> {
114        single_list_lambda_parameters(self.name(), fields)
115    }
116
117    fn return_field_from_args(
118        &self,
119        args: HigherOrderReturnFieldArgs,
120    ) -> Result<Arc<Field>> {
121        let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] =
122            take_function_args(self.name(), args.arg_fields)?
123        else {
124            return plan_err!("{} expects a value followed by a lambda", self.name());
125        };
126
127        //TODO: should metadata be copied into the transformed array?
128
129        // lambda is the resulting field of executing the lambda body
130        // with the parameters returned in lambda_parameters
131        let field = Arc::new(Field::new(
132            Field::LIST_FIELD_DEFAULT_NAME,
133            lambda.data_type().clone(),
134            lambda.is_nullable(),
135        ));
136
137        let return_type = match list.data_type() {
138            DataType::List(_) => DataType::List(field),
139            DataType::LargeList(_) => DataType::LargeList(field),
140            other => plan_err!("expected list, got {other}")?,
141        };
142
143        Ok(Arc::new(Field::new("", return_type, list.is_nullable())))
144    }
145
146    fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
147        let [list, lambda] = take_function_args(self.name(), &args.args)?;
148        let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
149        else {
150            return plan_err!("{} expects a value followed by a lambda", self.name());
151        };
152
153        let list_array = list.to_array(args.number_rows)?;
154
155        let list_values = match extract_list_values(&list_array, args.return_type())? {
156            ListValuesResult::EarlyReturn(v) => return Ok(v),
157            ListValuesResult::Values(v) => v,
158        };
159
160        // by passing closures, lambda.evaluate can evaluate only those actually needed
161        let values_param = || Ok(Arc::clone(&list_values));
162
163        // call the transforming lambda
164        let transformed_values = lambda
165            .evaluate(&[&values_param], |arrays| {
166                // if any column got captured, we need to adjust it to the values arrays,
167                // duplicating values of list with multitple values and removing values of empty lists
168                let indices = list_values_row_number(&list_array)?;
169                Ok(take_arrays(arrays, &indices, None)?)
170            })?
171            .into_array(list_values.len())?;
172
173        let field = match args.return_field.data_type() {
174            DataType::List(field) | DataType::LargeList(field) => Arc::clone(field),
175            _ => {
176                return exec_err!(
177                    "{} expected ScalarFunctionArgs.return_field to be a list, got {}",
178                    self.name(),
179                    args.return_field
180                );
181            }
182        };
183
184        let transformed_list = match list_array.data_type() {
185            DataType::List(_) => {
186                let list = list_array.as_list();
187
188                // since we called list_values above which would return sliced values for
189                // a sliced list, we must adjust the offsets here as otherwise they would be invalid
190                let adjusted_offsets = adjust_offsets_for_slice(list);
191
192                Arc::new(ListArray::new(
193                    field,
194                    adjusted_offsets,
195                    transformed_values,
196                    list.nulls().cloned(),
197                )) as ArrayRef
198            }
199            DataType::LargeList(_) => {
200                let large_list = list_array.as_list();
201
202                // since we called list_values above which would return sliced values for
203                // a sliced list, we must adjust the offsets here as otherwise they would be invalid
204                let adjusted_offsets = adjust_offsets_for_slice(large_list);
205
206                Arc::new(LargeListArray::new(
207                    field,
208                    adjusted_offsets,
209                    transformed_values,
210                    large_list.nulls().cloned(),
211                ))
212            }
213            other => exec_err!("expected list, got {other}")?,
214        };
215
216        Ok(ColumnarValue::Array(transformed_list))
217    }
218
219    fn documentation(&self) -> Option<&Documentation> {
220        self.doc()
221    }
222}
223
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array, AsArray},
        buffer::{NullBuffer, OffsetBuffer},
    };

    use crate::array_transform::array_transform_higher_order_function;
    use crate::lambda_utils::test_utils::{create_i32_list, eval_hof_on_i32_list, v};
    use datafusion_expr::lit;

    /// Evaluates `array_transform` on `list` with the lambda body `100 / v()`.
    fn divide_100_by(
        list: impl Array + Clone + 'static,
    ) -> datafusion_common::Result<arrow::array::ArrayRef> {
        let udf = array_transform_higher_order_function();
        let lambda_body = lit(100i32) / v();

        eval_hof_on_i32_list(udf, list, lambda_body)
    }

    #[test]
    fn transform_on_sliced_list_should_not_evaluate_on_unreachable_values() {
        // The leading 0 sits in the first row, which the slice below drops.
        // If the lambda were evaluated on it, the division would fail.
        let values = vec![0, 4, 100, 25, 20, 5, 2, 1, 10];
        let offsets = OffsetBuffer::<i32>::from_lengths(vec![1, 3, 4, 1]);
        let sliced = create_i32_list(values, offsets, None).slice(1, 3);

        let transformed = divide_100_by(sliced).unwrap();

        let expected = create_i32_list(
            vec![25, 1, 4, 5, 20, 50, 100, 10],
            OffsetBuffer::<i32>::from_lengths(vec![3, 4, 1]),
            None,
        );

        assert_eq!(transformed.as_list::<i32>(), &expected);
    }

    #[test]
    fn transform_function_should_not_be_evaluated_on_values_underlying_null() {
        // The middle row is null but still backs the values 0, 1, 2, 0.
        // Evaluating the lambda on those would fail with a division by zero.
        let values = vec![100, 20, 10, 0, 1, 2, 0, 1, 50];
        let offsets = OffsetBuffer::<i32>::from_lengths(vec![3, 4, 2]);
        let validity = NullBuffer::from(vec![true, false, true]);
        let input = create_i32_list(values, offsets, Some(validity));

        let transformed = divide_100_by(input).unwrap();
        let actual = transformed.as_list::<i32>();

        let expected = create_i32_list(
            vec![1, 5, 10, 100, 2],
            OffsetBuffer::<i32>::from_lengths(vec![3, 0, 2]),
            Some(NullBuffer::from(vec![true, false, true])),
        );

        assert_eq!(actual.data_type(), expected.data_type());
        assert_eq!(actual, &expected);
    }
}
293}