datafusion_functions_nested/
resize.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_resize function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    new_null_array, Array, ArrayRef, Capacities, GenericListArray, Int64Array,
23    MutableArrayData, NullBufferBuilder, OffsetSizeTrait,
24};
25use arrow::buffer::OffsetBuffer;
26use arrow::datatypes::DataType;
27use arrow::datatypes::{ArrowNativeType, Field};
28use arrow::datatypes::{
29    DataType::{FixedSizeList, LargeList, List},
30    FieldRef,
31};
32use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
33use datafusion_common::utils::ListCoercion;
34use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
35use datafusion_expr::{
36    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
37    ScalarUDFImpl, Signature, TypeSignature, Volatility,
38};
39use datafusion_macros::user_doc;
40use std::any::Any;
41use std::sync::Arc;
42
// Registers `ArrayResize` through the crate's `make_udf_expr_and_func!` macro.
// The arguments are, in order: the implementing type, the function name
// (`array_resize`), the argument names (`array size value`), a one-line
// description, and the identifier `array_resize_udf`.
// NOTE(review): the macro definition is not visible here; presumably it
// generates an `array_resize` expression builder and an `array_resize_udf`
// accessor — confirm against the macro.
43make_udf_expr_and_func!(
44    ArrayResize,
45    array_resize,
46    array size value,
47    "returns an array with the specified size filled with the given value.",
48    array_resize_udf
49);
50
// Scalar UDF type behind the `array_resize` SQL function.
//
// The `user_doc` attribute below carries the user-facing documentation
// (section, description, syntax, SQL example and per-argument help).
// NOTE(review): `documentation()` returns `self.doc()`, which is presumably
// generated by this attribute — confirm against `datafusion_macros`.
51#[user_doc(
52    doc_section(label = "Array Functions"),
53    description = "Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set.",
54    syntax_example = "array_resize(array, size, value)",
55    sql_example = r#"```sql
56> select array_resize([1, 2, 3], 5, 0);
57+-------------------------------------+
58| array_resize(List([1,2,3],5,0))     |
59+-------------------------------------+
60| [1, 2, 3, 0, 0]                     |
61+-------------------------------------+
62```"#,
63    argument(
64        name = "array",
65        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
66    ),
67    argument(name = "size", description = "New size of given array."),
68    argument(
69        name = "value",
70        description = "Defines new elements' value or empty if value is not set."
71    )
72)]
73#[derive(Debug)]
74pub struct ArrayResize {
    // Accepted argument signatures; handed out by `signature()`.
75    signature: Signature,
    // Alternative names for the function; handed out by `aliases()`.
76    aliases: Vec<String>,
77}
78
79impl Default for ArrayResize {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl ArrayResize {
86    pub fn new() -> Self {
87        Self {
88            signature: Signature::one_of(
89                vec![
90                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
91                        arguments: vec![
92                            ArrayFunctionArgument::Array,
93                            ArrayFunctionArgument::Index,
94                        ],
95                        array_coercion: Some(ListCoercion::FixedSizedListToList),
96                    }),
97                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
98                        arguments: vec![
99                            ArrayFunctionArgument::Array,
100                            ArrayFunctionArgument::Index,
101                            ArrayFunctionArgument::Element,
102                        ],
103                        array_coercion: Some(ListCoercion::FixedSizedListToList),
104                    }),
105                ],
106                Volatility::Immutable,
107            ),
108            aliases: vec!["list_resize".to_string()],
109        }
110    }
111}
112
113impl ScalarUDFImpl for ArrayResize {
114    fn as_any(&self) -> &dyn Any {
115        self
116    }
117
118    fn name(&self) -> &str {
119        "array_resize"
120    }
121
122    fn signature(&self) -> &Signature {
123        &self.signature
124    }
125
126    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
127        match &arg_types[0] {
128            List(field) | FixedSizeList(field, _) => Ok(List(Arc::clone(field))),
129            LargeList(field) => Ok(LargeList(Arc::clone(field))),
130            DataType::Null => {
131                Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true))))
132            }
133            _ => exec_err!(
134                "Not reachable, data_type should be List, LargeList or FixedSizeList"
135            ),
136        }
137    }
138
139    fn invoke_with_args(
140        &self,
141        args: datafusion_expr::ScalarFunctionArgs,
142    ) -> Result<ColumnarValue> {
143        make_scalar_function(array_resize_inner)(&args.args)
144    }
145
146    fn aliases(&self) -> &[String] {
147        &self.aliases
148    }
149
150    fn documentation(&self) -> Option<&Documentation> {
151        self.doc()
152    }
153}
154
155/// array_resize SQL function
156pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
157    if arg.len() < 2 || arg.len() > 3 {
158        return exec_err!("array_resize needs two or three arguments");
159    }
160
161    let array = &arg[0];
162
163    // Checks if entire array is null
164    if array.logical_null_count() == array.len() {
165        let return_type = match array.data_type() {
166            List(field) => List(Arc::clone(field)),
167            LargeList(field) => LargeList(Arc::clone(field)),
168            _ => {
169                return exec_err!(
170                    "array_resize does not support type '{:?}'.",
171                    array.data_type()
172                )
173            }
174        };
175        return Ok(new_null_array(&return_type, array.len()));
176    }
177
178    let new_len = as_int64_array(&arg[1])?;
179    let new_element = if arg.len() == 3 {
180        Some(Arc::clone(&arg[2]))
181    } else {
182        None
183    };
184
185    match &arg[0].data_type() {
186        List(field) => {
187            let array = as_list_array(&arg[0])?;
188            general_list_resize::<i32>(array, new_len, field, new_element)
189        }
190        LargeList(field) => {
191            let array = as_large_list_array(&arg[0])?;
192            general_list_resize::<i64>(array, new_len, field, new_element)
193        }
194        array_type => exec_err!("array_resize does not support type '{array_type:?}'."),
195    }
196}
197
198/// array_resize keep the original array and append the default element to the end
199fn general_list_resize<O: OffsetSizeTrait + TryInto<i64>>(
200    array: &GenericListArray<O>,
201    count_array: &Int64Array,
202    field: &FieldRef,
203    default_element: Option<ArrayRef>,
204) -> Result<ArrayRef> {
205    let data_type = array.value_type();
206
207    let values = array.values();
208    let original_data = values.to_data();
209
210    // create default element array
211    let default_element = if let Some(default_element) = default_element {
212        default_element
213    } else {
214        let null_scalar = ScalarValue::try_from(&data_type)?;
215        null_scalar.to_array_of_size(original_data.len())?
216    };
217    let default_value_data = default_element.to_data();
218
219    // create a mutable array to store the original data
220    let capacity = Capacities::Array(original_data.len() + default_value_data.len());
221    let mut offsets = vec![O::usize_as(0)];
222    let mut mutable = MutableArrayData::with_capacities(
223        vec![&original_data, &default_value_data],
224        false,
225        capacity,
226    );
227
228    let mut null_builder = NullBufferBuilder::new(array.len());
229
230    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
231        if array.is_null(row_index) {
232            null_builder.append_null();
233            offsets.push(offsets[row_index]);
234            continue;
235        }
236        null_builder.append_non_null();
237
238        let count = count_array.value(row_index).to_usize().ok_or_else(|| {
239            internal_datafusion_err!("array_resize: failed to convert size to usize")
240        })?;
241        let count = O::usize_as(count);
242        let start = offset_window[0];
243        if start + count > offset_window[1] {
244            let extra_count =
245                (start + count - offset_window[1]).try_into().map_err(|_| {
246                    internal_datafusion_err!(
247                        "array_resize: failed to convert size to i64"
248                    )
249                })?;
250            let end = offset_window[1];
251            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
252            // append default element
253            for _ in 0..extra_count {
254                mutable.extend(1, row_index, row_index + 1);
255            }
256        } else {
257            let end = start + count;
258            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
259        };
260        offsets.push(offsets[row_index] + count);
261    }
262
263    let data = mutable.freeze();
264
265    Ok(Arc::new(GenericListArray::<O>::try_new(
266        Arc::clone(field),
267        OffsetBuffer::<O>::new(offsets.into()),
268        arrow::array::make_array(data),
269        null_builder.finish(),
270    )?))
271}