Skip to main content

datafusion_functions_nested/
resize.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_resize function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, Capacities, GenericListArray, Int64Array, MutableArrayData,
23    NullBufferBuilder, OffsetSizeTrait, new_null_array,
24};
25use arrow::buffer::OffsetBuffer;
26use arrow::datatypes::DataType;
27use arrow::datatypes::{ArrowNativeType, Field};
28use arrow::datatypes::{
29    DataType::{LargeList, List},
30    FieldRef,
31};
32use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
33use datafusion_common::utils::ListCoercion;
34use datafusion_common::{Result, ScalarValue, exec_err, internal_datafusion_err};
35use datafusion_expr::{
36    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
37    ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
38};
39use datafusion_macros::user_doc;
40use std::sync::Arc;
41
42make_udf_expr_and_func!(
43    ArrayResize,
44    array_resize,
45    array size value,
46    "returns an array with the specified size filled with the given value.",
47    array_resize_udf
48);
49
50#[user_doc(
51    doc_section(label = "Array Functions"),
52    description = "Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set.",
53    syntax_example = "array_resize(array, size, value)",
54    sql_example = r#"```sql
55> select array_resize([1, 2, 3], 5, 0);
56+-------------------------------------+
57| array_resize(List([1,2,3],5,0))     |
58+-------------------------------------+
59| [1, 2, 3, 0, 0]                     |
60+-------------------------------------+
61```"#,
62    argument(
63        name = "array",
64        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
65    ),
66    argument(name = "size", description = "New size of given array."),
67    argument(
68        name = "value",
69        description = "Defines new elements' value or empty if value is not set."
70    )
71)]
72#[derive(Debug, PartialEq, Eq, Hash)]
73pub struct ArrayResize {
74    signature: Signature,
75    aliases: Vec<String>,
76}
77
78impl Default for ArrayResize {
79    fn default() -> Self {
80        Self::new()
81    }
82}
83
84impl ArrayResize {
85    pub fn new() -> Self {
86        Self {
87            signature: Signature::one_of(
88                vec![
89                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
90                        arguments: vec![
91                            ArrayFunctionArgument::Array,
92                            ArrayFunctionArgument::Index,
93                        ],
94                        array_coercion: Some(ListCoercion::FixedSizedListToList),
95                    }),
96                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
97                        arguments: vec![
98                            ArrayFunctionArgument::Array,
99                            ArrayFunctionArgument::Index,
100                            ArrayFunctionArgument::Element,
101                        ],
102                        array_coercion: Some(ListCoercion::FixedSizedListToList),
103                    }),
104                ],
105                Volatility::Immutable,
106            ),
107            aliases: vec!["list_resize".to_string()],
108        }
109    }
110}
111
112impl ScalarUDFImpl for ArrayResize {
113    fn name(&self) -> &str {
114        "array_resize"
115    }
116
117    fn signature(&self) -> &Signature {
118        &self.signature
119    }
120
121    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
122        match &arg_types[0] {
123            List(field) => Ok(List(Arc::clone(field))),
124            LargeList(field) => Ok(LargeList(Arc::clone(field))),
125            DataType::Null => {
126                Ok(List(Arc::new(Field::new_list_field(DataType::Int64, true))))
127            }
128            _ => exec_err!(
129                "Not reachable, data_type should be List, LargeList or FixedSizeList"
130            ),
131        }
132    }
133
134    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
135        make_scalar_function(array_resize_inner)(&args.args)
136    }
137
138    fn aliases(&self) -> &[String] {
139        &self.aliases
140    }
141
142    fn documentation(&self) -> Option<&Documentation> {
143        self.doc()
144    }
145}
146
147fn array_resize_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
148    if arg.len() < 2 || arg.len() > 3 {
149        return exec_err!("array_resize needs two or three arguments");
150    }
151
152    let array = &arg[0];
153
154    // Checks if entire array is null
155    if array.logical_null_count() == array.len() {
156        let return_type = match array.data_type() {
157            List(field) => List(Arc::clone(field)),
158            LargeList(field) => LargeList(Arc::clone(field)),
159            _ => {
160                return exec_err!(
161                    "array_resize does not support type '{:?}'.",
162                    array.data_type()
163                );
164            }
165        };
166        return Ok(new_null_array(&return_type, array.len()));
167    }
168
169    let new_len = as_int64_array(&arg[1])?;
170    let new_element = if arg.len() == 3 {
171        Some(Arc::clone(&arg[2]))
172    } else {
173        None
174    };
175
176    match &arg[0].data_type() {
177        List(field) => {
178            let array = as_list_array(&arg[0])?;
179            general_list_resize::<i32>(array, new_len, field, new_element)
180        }
181        LargeList(field) => {
182            let array = as_large_list_array(&arg[0])?;
183            general_list_resize::<i64>(array, new_len, field, new_element)
184        }
185        array_type => exec_err!("array_resize does not support type '{array_type}'."),
186    }
187}
188
189/// array_resize keep the original array and append the default element to the end
190fn general_list_resize<O: OffsetSizeTrait + TryInto<i64>>(
191    array: &GenericListArray<O>,
192    count_array: &Int64Array,
193    field: &FieldRef,
194    default_element: Option<ArrayRef>,
195) -> Result<ArrayRef> {
196    let data_type = array.value_type();
197
198    let values = array.values();
199    let original_data = values.to_data();
200
201    // Track the largest per-row growth so the uniform-fill fast path can
202    // materialize one reusable fill buffer of the required size.
203    let mut max_extra: usize = 0;
204    let mut output_values_len: usize = 0;
205    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
206        if array.is_null(row_index) {
207            continue;
208        }
209        let target_count = count_array.value(row_index).to_usize().ok_or_else(|| {
210            internal_datafusion_err!("array_resize: failed to convert size to usize")
211        })?;
212        output_values_len =
213            output_values_len.checked_add(target_count).ok_or_else(|| {
214                internal_datafusion_err!("array_resize: output size overflow")
215            })?;
216        let current_len = (offset_window[1] - offset_window[0]).to_usize().unwrap();
217        if target_count > current_len {
218            max_extra = max_extra.max(target_count - current_len);
219        }
220    }
221
222    // The fast path is valid when at least one row grows and every row would
223    // use the same fill value.
224    let use_bulk_fill = max_extra > 0
225        && match &default_element {
226            None => true,
227            Some(fill_array) => {
228                let len = fill_array.len();
229                let null_count = fill_array.logical_null_count();
230
231                len <= 1
232                    || null_count == len
233                    || (null_count == 0 && {
234                        let first = fill_array.slice(0, 1);
235                        (1..len)
236                            .all(|i| fill_array.slice(i, 1).as_ref() == first.as_ref())
237                    })
238            }
239        };
240
241    if use_bulk_fill {
242        // Fast path: materialize one reusable fill buffer for all grown rows.
243        let fill_scalar = match &default_element {
244            None => ScalarValue::try_from(&data_type)?,
245            Some(fill_array) if fill_array.logical_null_count() == fill_array.len() => {
246                ScalarValue::try_from(&data_type)?
247            }
248            Some(fill_array) => ScalarValue::try_from_array(fill_array.as_ref(), 0)?,
249        };
250        let fill_values = fill_scalar.to_array_of_size(max_extra)?;
251        let default_value_data = fill_values.to_data();
252        build_resized_list(
253            array,
254            count_array,
255            field,
256            &original_data,
257            &default_value_data,
258            output_values_len,
259            |mutable, _, extra_count| mutable.extend(1, 0, extra_count),
260        )
261    } else {
262        // Slow path: rows may need different fill values, so append from the
263        // corresponding slot in the input fill array for each grown element.
264        let fill_values = match default_element {
265            Some(fill_values) => fill_values,
266            None => {
267                let null_scalar = ScalarValue::try_from(&data_type)?;
268                null_scalar.to_array_of_size(original_data.len())?
269            }
270        };
271        let default_value_data = fill_values.to_data();
272        build_resized_list(
273            array,
274            count_array,
275            field,
276            &original_data,
277            &default_value_data,
278            output_values_len,
279            |mutable, row_index, extra_count| {
280                for _ in 0..extra_count {
281                    mutable.extend(1, row_index, row_index + 1);
282                }
283            },
284        )
285    }
286}
287
288fn build_resized_list<O, F>(
289    array: &GenericListArray<O>,
290    count_array: &Int64Array,
291    field: &FieldRef,
292    original_data: &arrow::array::ArrayData,
293    default_value_data: &arrow::array::ArrayData,
294    output_values_len: usize,
295    mut append_fill_values: F,
296) -> Result<ArrayRef>
297where
298    O: OffsetSizeTrait + TryInto<i64>,
299    F: FnMut(&mut MutableArrayData, usize, usize),
300{
301    let capacity = Capacities::Array(output_values_len);
302    let mut offsets = vec![O::usize_as(0)];
303    let mut mutable = MutableArrayData::with_capacities(
304        vec![original_data, default_value_data],
305        false,
306        capacity,
307    );
308    let mut null_builder = NullBufferBuilder::new(array.len());
309
310    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
311        if array.is_null(row_index) {
312            null_builder.append_null();
313            offsets.push(offsets[row_index]);
314            continue;
315        }
316        null_builder.append_non_null();
317
318        let count = count_array.value(row_index).to_usize().ok_or_else(|| {
319            internal_datafusion_err!("array_resize: failed to convert size to usize")
320        })?;
321        let count = O::usize_as(count);
322        let start = offset_window[0];
323        if start + count > offset_window[1] {
324            let extra_count = (start + count - offset_window[1]).to_usize().unwrap();
325            let end = offset_window[1];
326            mutable.extend(0, start.to_usize().unwrap(), end.to_usize().unwrap());
327            append_fill_values(&mut mutable, row_index, extra_count);
328        } else {
329            let end = start + count;
330            mutable.extend(0, start.to_usize().unwrap(), end.to_usize().unwrap());
331        };
332        offsets.push(offsets[row_index] + count);
333    }
334
335    let data = mutable.freeze();
336
337    Ok(Arc::new(GenericListArray::<O>::try_new(
338        Arc::clone(field),
339        OffsetBuffer::<O>::new(offsets.into()),
340        arrow::array::make_array(data),
341        null_builder.finish(),
342    )?))
343}