1use arrow::{
21 array::{Array, ArrayRef, AsArray, LargeListArray, ListArray},
22 compute::take_arrays,
23 datatypes::{DataType, Field, FieldRef},
24};
25use datafusion_common::{
26 Result, exec_err, plan_err,
27 utils::{adjust_offsets_for_slice, list_values_row_number, take_function_args},
28};
29use datafusion_expr::{
30 ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs,
31 HigherOrderSignature, HigherOrderUDFImpl, LambdaParametersProgress, ValueOrLambda,
32 Volatility,
33};
34use datafusion_macros::user_doc;
35use std::sync::Arc;
36
37use crate::lambda_utils::{
38 ListValuesResult, coerce_single_list_arg, extract_list_values,
39 single_list_lambda_parameters,
40};
41
// Invokes the crate-level `make_higher_order_function_expr_and_func!` macro for
// `ArrayTransform`, passing the expression-helper name (`array_transform`), its
// argument names (`array`, `lambda`), a one-line description, and the name
// `array_transform_higher_order_function`.
// NOTE(review): the tests in this file call `array_transform_higher_order_function()`,
// so the macro presumably generates that accessor (and an `array_transform`
// expression builder) — confirm against the macro definition.
42make_higher_order_function_expr_and_func!(
43 ArrayTransform,
44 array_transform,
45 array lambda,
46 "transforms the values of an array",
47 array_transform_higher_order_function
48);
49
/// Higher-order function `array_transform(array, lambda)`: produces a list whose
/// elements are the result of evaluating `lambda` on the elements of `array`.
///
/// The user-facing documentation (section, description, syntax, SQL example and
/// argument descriptions) is declared through the `#[user_doc]` attribute below
/// and exposed by `documentation()` via `self.doc()`.
50#[user_doc(
51 doc_section(label = "Array Functions"),
52 description = "transforms the values of an array",
53 syntax_example = "array_transform(array, x -> x*2)",
54 sql_example = r#"```sql
55> select array_transform([1, 2, 3, 4, 5], x -> x*2);
56+-------------------------------------------+
57| array_transform([1, 2, 3, 4, 5], x -> x*2) |
58+-------------------------------------------+
59| [2, 4, 6, 8, 10] |
60+-------------------------------------------+
61```"#,
62 argument(
63 name = "array",
64 description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
65 ),
66 argument(name = "lambda", description = "Lambda")
67)]
68#[derive(Debug, PartialEq, Eq, Hash)]
69pub struct ArrayTransform {
    /// Accepted argument shape; returned by `signature()`.
70 signature: HigherOrderSignature,
    /// Alternative function names; returned by `aliases()`.
71 aliases: Vec<String>,
72}
73
74impl Default for ArrayTransform {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl ArrayTransform {
81 pub fn new() -> Self {
82 Self {
83 signature: HigherOrderSignature::exact(
84 vec![ValueOrLambda::Value(()), ValueOrLambda::Lambda(())],
85 Volatility::Immutable,
86 ),
87 aliases: vec![String::from("list_transform")],
88 }
89 }
90}
91
92impl HigherOrderUDFImpl for ArrayTransform {
93 fn name(&self) -> &str {
94 "array_transform"
95 }
96
97 fn aliases(&self) -> &[String] {
98 &self.aliases
99 }
100
101 fn signature(&self) -> &HigherOrderSignature {
102 &self.signature
103 }
104
105 fn coerce_value_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
106 coerce_single_list_arg(self.name(), arg_types)
107 }
108
109 fn lambda_parameters(
110 &self,
111 _step: usize,
112 fields: &[ValueOrLambda<FieldRef, Option<FieldRef>>],
113 ) -> Result<LambdaParametersProgress> {
114 single_list_lambda_parameters(self.name(), fields)
115 }
116
117 fn return_field_from_args(
118 &self,
119 args: HigherOrderReturnFieldArgs,
120 ) -> Result<Arc<Field>> {
121 let [ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)] =
122 take_function_args(self.name(), args.arg_fields)?
123 else {
124 return plan_err!("{} expects a value followed by a lambda", self.name());
125 };
126
127 let field = Arc::new(Field::new(
132 Field::LIST_FIELD_DEFAULT_NAME,
133 lambda.data_type().clone(),
134 lambda.is_nullable(),
135 ));
136
137 let return_type = match list.data_type() {
138 DataType::List(_) => DataType::List(field),
139 DataType::LargeList(_) => DataType::LargeList(field),
140 other => plan_err!("expected list, got {other}")?,
141 };
142
143 Ok(Arc::new(Field::new("", return_type, list.is_nullable())))
144 }
145
146 fn invoke_with_args(&self, args: HigherOrderFunctionArgs) -> Result<ColumnarValue> {
147 let [list, lambda] = take_function_args(self.name(), &args.args)?;
148 let (ValueOrLambda::Value(list), ValueOrLambda::Lambda(lambda)) = (list, lambda)
149 else {
150 return plan_err!("{} expects a value followed by a lambda", self.name());
151 };
152
153 let list_array = list.to_array(args.number_rows)?;
154
155 let list_values = match extract_list_values(&list_array, args.return_type())? {
156 ListValuesResult::EarlyReturn(v) => return Ok(v),
157 ListValuesResult::Values(v) => v,
158 };
159
160 let values_param = || Ok(Arc::clone(&list_values));
162
163 let transformed_values = lambda
165 .evaluate(&[&values_param], |arrays| {
166 let indices = list_values_row_number(&list_array)?;
169 Ok(take_arrays(arrays, &indices, None)?)
170 })?
171 .into_array(list_values.len())?;
172
173 let field = match args.return_field.data_type() {
174 DataType::List(field) | DataType::LargeList(field) => Arc::clone(field),
175 _ => {
176 return exec_err!(
177 "{} expected ScalarFunctionArgs.return_field to be a list, got {}",
178 self.name(),
179 args.return_field
180 );
181 }
182 };
183
184 let transformed_list = match list_array.data_type() {
185 DataType::List(_) => {
186 let list = list_array.as_list();
187
188 let adjusted_offsets = adjust_offsets_for_slice(list);
191
192 Arc::new(ListArray::new(
193 field,
194 adjusted_offsets,
195 transformed_values,
196 list.nulls().cloned(),
197 )) as ArrayRef
198 }
199 DataType::LargeList(_) => {
200 let large_list = list_array.as_list();
201
202 let adjusted_offsets = adjust_offsets_for_slice(large_list);
205
206 Arc::new(LargeListArray::new(
207 field,
208 adjusted_offsets,
209 transformed_values,
210 large_list.nulls().cloned(),
211 ))
212 }
213 other => exec_err!("expected list, got {other}")?,
214 };
215
216 Ok(ColumnarValue::Array(transformed_list))
217 }
218
219 fn documentation(&self) -> Option<&Documentation> {
220 self.doc()
221 }
222}
223
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array, AsArray},
        buffer::{NullBuffer, OffsetBuffer},
    };

    use crate::array_transform::array_transform_higher_order_function;
    use crate::lambda_utils::test_utils::{create_i32_list, eval_hof_on_i32_list, v};
    use datafusion_expr::lit;

    /// Evaluates `array_transform` on `list` with the lambda body `100 / v()`.
    fn divide_100_by(
        list: impl Array + Clone + 'static,
    ) -> datafusion_common::Result<arrow::array::ArrayRef> {
        let function = array_transform_higher_order_function();
        let lambda_body = lit(100i32) / v();

        eval_hof_on_i32_list(function, list, lambda_body)
    }

    #[test]
    fn transform_on_sliced_list_should_not_evaluate_on_unreachable_values() {
        // The leading `0` belongs to the first row, which the slice below drops;
        // the lambda must not be applied to it.
        let values = vec![0, 4, 100, 25, 20, 5, 2, 1, 10];
        let offsets = OffsetBuffer::<i32>::from_lengths(vec![1, 3, 4, 1]);
        let sliced = create_i32_list(values, offsets, None).slice(1, 3);

        let transformed = divide_100_by(sliced).unwrap();

        let expected = create_i32_list(
            vec![25, 1, 4, 5, 20, 50, 100, 10],
            OffsetBuffer::<i32>::from_lengths(vec![3, 4, 1]),
            None,
        );

        assert_eq!(transformed.as_list::<i32>(), &expected);
    }

    #[test]
    fn transform_function_should_not_be_evaluated_on_values_underlying_null() {
        // The middle row is null yet still owns four values, two of them `0`;
        // the lambda must not be applied to any of them.
        let values = vec![100, 20, 10, 0, 1, 2, 0, 1, 50];
        let offsets = OffsetBuffer::<i32>::from_lengths(vec![3, 4, 2]);
        let validity = NullBuffer::from(vec![true, false, true]);
        let input = create_i32_list(values, offsets, Some(validity));

        let transformed = divide_100_by(input).unwrap();
        let transformed = transformed.as_list::<i32>();

        // The null row is expected to come back with zero length.
        let expected = create_i32_list(
            vec![1, 5, 10, 100, 2],
            OffsetBuffer::<i32>::from_lengths(vec![3, 0, 2]),
            Some(NullBuffer::from(vec![true, false, true])),
        );

        assert_eq!(transformed.data_type(), expected.data_type());
        assert_eq!(transformed, &expected);
    }
}