datafusion_functions_nested/
resize.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_resize function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, Capacities, GenericListArray, Int64Array, MutableArrayData,
23    NullBufferBuilder, OffsetSizeTrait, new_null_array,
24};
25use arrow::buffer::OffsetBuffer;
26use arrow::datatypes::DataType;
27use arrow::datatypes::{ArrowNativeType, Field};
28use arrow::datatypes::{
29    DataType::{LargeList, List},
30    FieldRef,
31};
32use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
33use datafusion_common::utils::ListCoercion;
34use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err};
35use datafusion_expr::{
36    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
37    ScalarUDFImpl, Signature, TypeSignature, Volatility,
38};
39use datafusion_macros::user_doc;
40use std::any::Any;
41use std::sync::Arc;
42
// Registers `ArrayResize` through the crate's `make_udf_expr_and_func!` macro.
// NOTE(review): the macro is defined outside this file; presumably it generates
// an `array_resize(array, size, value)` expression helper plus an
// `array_resize_udf()` accessor for the shared UDF instance — confirm against
// the macro definition.
make_udf_expr_and_func!(
    ArrayResize,
    array_resize,
    array size value,
    "returns an array with the specified size filled with the given value.",
    array_resize_udf
);
50
/// Scalar UDF implementation backing the `array_resize` SQL function
/// (also reachable through the `list_resize` alias set up in [`ArrayResize::new`]).
///
/// The `#[user_doc]` attribute below carries the user-facing documentation
/// (description, syntax, SQL example and per-argument help).
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set.",
    syntax_example = "array_resize(array, size, value)",
    sql_example = r#"```sql
> select array_resize([1, 2, 3], 5, 0);
+-------------------------------------+
| array_resize(List([1,2,3],5,0))     |
+-------------------------------------+
| [1, 2, 3, 0, 0]                     |
+-------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "size", description = "New size of given array."),
    argument(
        name = "value",
        description = "Defines new elements' value or empty if value is not set."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayResize {
    /// Accepted argument signatures; handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative function names; handed out by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
78
79impl Default for ArrayResize {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl ArrayResize {
86    pub fn new() -> Self {
87        Self {
88            signature: Signature::one_of(
89                vec![
90                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
91                        arguments: vec![
92                            ArrayFunctionArgument::Array,
93                            ArrayFunctionArgument::Index,
94                        ],
95                        array_coercion: Some(ListCoercion::FixedSizedListToList),
96                    }),
97                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
98                        arguments: vec![
99                            ArrayFunctionArgument::Array,
100                            ArrayFunctionArgument::Index,
101                            ArrayFunctionArgument::Element,
102                        ],
103                        array_coercion: Some(ListCoercion::FixedSizedListToList),
104                    }),
105                ],
106                Volatility::Immutable,
107            ),
108            aliases: vec!["list_resize".to_string()],
109        }
110    }
111}
112
113impl ScalarUDFImpl for ArrayResize {
114    fn as_any(&self) -> &dyn Any {
115        self
116    }
117
118    fn name(&self) -> &str {
119        "array_resize"
120    }
121
122    fn signature(&self) -> &Signature {
123        &self.signature
124    }
125
126    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
127        match &arg_types[0] {
128            List(field) => Ok(List(Arc::clone(field))),
129            LargeList(field) => Ok(LargeList(Arc::clone(field))),
130            DataType::Null => {
131                Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true))))
132            }
133            _ => exec_err!(
134                "Not reachable, data_type should be List, LargeList or FixedSizeList"
135            ),
136        }
137    }
138
139    fn invoke_with_args(
140        &self,
141        args: datafusion_expr::ScalarFunctionArgs,
142    ) -> Result<ColumnarValue> {
143        make_scalar_function(array_resize_inner)(&args.args)
144    }
145
146    fn aliases(&self) -> &[String] {
147        &self.aliases
148    }
149
150    fn documentation(&self) -> Option<&Documentation> {
151        self.doc()
152    }
153}
154
155fn array_resize_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
156    if arg.len() < 2 || arg.len() > 3 {
157        return exec_err!("array_resize needs two or three arguments");
158    }
159
160    let array = &arg[0];
161
162    // Checks if entire array is null
163    if array.logical_null_count() == array.len() {
164        let return_type = match array.data_type() {
165            List(field) => List(Arc::clone(field)),
166            LargeList(field) => LargeList(Arc::clone(field)),
167            _ => {
168                return exec_err!(
169                    "array_resize does not support type '{:?}'.",
170                    array.data_type()
171                );
172            }
173        };
174        return Ok(new_null_array(&return_type, array.len()));
175    }
176
177    let new_len = as_int64_array(&arg[1])?;
178    let new_element = if arg.len() == 3 {
179        Some(Arc::clone(&arg[2]))
180    } else {
181        None
182    };
183
184    match &arg[0].data_type() {
185        List(field) => {
186            let array = as_list_array(&arg[0])?;
187            general_list_resize::<i32>(array, new_len, field, new_element)
188        }
189        LargeList(field) => {
190            let array = as_large_list_array(&arg[0])?;
191            general_list_resize::<i64>(array, new_len, field, new_element)
192        }
193        array_type => exec_err!("array_resize does not support type '{array_type}'."),
194    }
195}
196
197/// array_resize keep the original array and append the default element to the end
198fn general_list_resize<O: OffsetSizeTrait + TryInto<i64>>(
199    array: &GenericListArray<O>,
200    count_array: &Int64Array,
201    field: &FieldRef,
202    default_element: Option<ArrayRef>,
203) -> Result<ArrayRef> {
204    let data_type = array.value_type();
205
206    let values = array.values();
207    let original_data = values.to_data();
208
209    // create default element array
210    let default_element = if let Some(default_element) = default_element {
211        default_element
212    } else {
213        let null_scalar = ScalarValue::try_from(&data_type)?;
214        null_scalar.to_array_of_size(original_data.len())?
215    };
216    let default_value_data = default_element.to_data();
217
218    // create a mutable array to store the original data
219    let capacity = Capacities::Array(original_data.len() + default_value_data.len());
220    let mut offsets = vec![O::usize_as(0)];
221    let mut mutable = MutableArrayData::with_capacities(
222        vec![&original_data, &default_value_data],
223        false,
224        capacity,
225    );
226
227    let mut null_builder = NullBufferBuilder::new(array.len());
228
229    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
230        if array.is_null(row_index) {
231            null_builder.append_null();
232            offsets.push(offsets[row_index]);
233            continue;
234        }
235        null_builder.append_non_null();
236
237        let count = count_array.value(row_index).to_usize().ok_or_else(|| {
238            internal_datafusion_err!("array_resize: failed to convert size to usize")
239        })?;
240        let count = O::usize_as(count);
241        let start = offset_window[0];
242        if start + count > offset_window[1] {
243            let extra_count =
244                (start + count - offset_window[1]).try_into().map_err(|_| {
245                    internal_datafusion_err!(
246                        "array_resize: failed to convert size to i64"
247                    )
248                })?;
249            let end = offset_window[1];
250            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
251            // append default element
252            for _ in 0..extra_count {
253                mutable.extend(1, row_index, row_index + 1);
254            }
255        } else {
256            let end = start + count;
257            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
258        };
259        offsets.push(offsets[row_index] + count);
260    }
261
262    let data = mutable.freeze();
263
264    Ok(Arc::new(GenericListArray::<O>::try_new(
265        Arc::clone(field),
266        OffsetBuffer::<O>::new(offsets.into()),
267        arrow::array::make_array(data),
268        null_builder.finish(),
269    )?))
270}