datafusion_functions_nested/
resize.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! [`ScalarUDFImpl`] definition for the `array_resize` function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    new_null_array, Array, ArrayRef, Capacities, GenericListArray, Int64Array,
23    MutableArrayData, NullBufferBuilder, OffsetSizeTrait,
24};
25use arrow::buffer::OffsetBuffer;
26use arrow::datatypes::ArrowNativeType;
27use arrow::datatypes::DataType;
28use arrow::datatypes::{
29    DataType::{FixedSizeList, LargeList, List},
30    FieldRef,
31};
32use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
33use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
34use datafusion_expr::{
35    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
36};
37use datafusion_macros::user_doc;
38use std::any::Any;
39use std::sync::Arc;
40
// Registers `ArrayResize` with the crate's `make_udf_expr_and_func!` macro
// (defined outside this file). The arguments name the UDF type, the
// `array_resize` function, its three parameters (`array`, `size`, `value`),
// a one-line description, and the `array_resize_udf` identifier.
// NOTE(review): presumably this expands to an expression-builder function and
// a UDF accessor — confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayResize,
    array_resize,
    array size value,
    "returns an array with the specified size filled with the given value.",
    array_resize_udf
);
48
/// Scalar UDF type behind the `array_resize` SQL function.
///
/// The `user_doc` attribute below carries the user-facing documentation
/// (description, syntax, SQL example and argument descriptions).
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Resizes the list to contain size elements. Initializes new elements with value or empty if value is not set.",
    syntax_example = "array_resize(array, size, value)",
    sql_example = r#"```sql
> select array_resize([1, 2, 3], 5, 0);
+-------------------------------------+
| array_resize(List([1,2,3],5,0))     |
+-------------------------------------+
| [1, 2, 3, 0, 0]                     |
+-------------------------------------+
```"#,
    argument(
        name = "array",
        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(name = "size", description = "New size of given array."),
    argument(
        name = "value",
        description = "Defines new elements' value or empty if value is not set."
    )
)]
#[derive(Debug)]
pub struct ArrayResize {
    /// Argument signature returned by `ScalarUDFImpl::signature`.
    signature: Signature,
    /// Alternative function names returned by `ScalarUDFImpl::aliases`.
    aliases: Vec<String>,
}
76
77impl Default for ArrayResize {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83impl ArrayResize {
84    pub fn new() -> Self {
85        Self {
86            signature: Signature::variadic_any(Volatility::Immutable),
87            aliases: vec!["list_resize".to_string()],
88        }
89    }
90}
91
92impl ScalarUDFImpl for ArrayResize {
93    fn as_any(&self) -> &dyn Any {
94        self
95    }
96
97    fn name(&self) -> &str {
98        "array_resize"
99    }
100
101    fn signature(&self) -> &Signature {
102        &self.signature
103    }
104
105    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
106        match &arg_types[0] {
107            List(field) | FixedSizeList(field, _) => Ok(List(Arc::clone(field))),
108            LargeList(field) => Ok(LargeList(Arc::clone(field))),
109            _ => exec_err!(
110                "Not reachable, data_type should be List, LargeList or FixedSizeList"
111            ),
112        }
113    }
114
115    fn invoke_with_args(
116        &self,
117        args: datafusion_expr::ScalarFunctionArgs,
118    ) -> Result<ColumnarValue> {
119        make_scalar_function(array_resize_inner)(&args.args)
120    }
121
122    fn aliases(&self) -> &[String] {
123        &self.aliases
124    }
125
126    fn documentation(&self) -> Option<&Documentation> {
127        self.doc()
128    }
129}
130
131/// array_resize SQL function
132pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result<ArrayRef> {
133    if arg.len() < 2 || arg.len() > 3 {
134        return exec_err!("array_resize needs two or three arguments");
135    }
136
137    let array = &arg[0];
138
139    // Checks if entire array is null
140    if array.null_count() == array.len() {
141        let return_type = match array.data_type() {
142            List(field) => List(Arc::clone(field)),
143            LargeList(field) => LargeList(Arc::clone(field)),
144            _ => {
145                return exec_err!(
146                    "array_resize does not support type '{:?}'.",
147                    array.data_type()
148                )
149            }
150        };
151        return Ok(new_null_array(&return_type, array.len()));
152    }
153
154    let new_len = as_int64_array(&arg[1])?;
155    let new_element = if arg.len() == 3 {
156        Some(Arc::clone(&arg[2]))
157    } else {
158        None
159    };
160
161    match &arg[0].data_type() {
162        List(field) => {
163            let array = as_list_array(&arg[0])?;
164            general_list_resize::<i32>(array, new_len, field, new_element)
165        }
166        LargeList(field) => {
167            let array = as_large_list_array(&arg[0])?;
168            general_list_resize::<i64>(array, new_len, field, new_element)
169        }
170        array_type => exec_err!("array_resize does not support type '{array_type:?}'."),
171    }
172}
173
174/// array_resize keep the original array and append the default element to the end
175fn general_list_resize<O: OffsetSizeTrait + TryInto<i64>>(
176    array: &GenericListArray<O>,
177    count_array: &Int64Array,
178    field: &FieldRef,
179    default_element: Option<ArrayRef>,
180) -> Result<ArrayRef> {
181    let data_type = array.value_type();
182
183    let values = array.values();
184    let original_data = values.to_data();
185
186    // create default element array
187    let default_element = if let Some(default_element) = default_element {
188        default_element
189    } else {
190        let null_scalar = ScalarValue::try_from(&data_type)?;
191        null_scalar.to_array_of_size(original_data.len())?
192    };
193    let default_value_data = default_element.to_data();
194
195    // create a mutable array to store the original data
196    let capacity = Capacities::Array(original_data.len() + default_value_data.len());
197    let mut offsets = vec![O::usize_as(0)];
198    let mut mutable = MutableArrayData::with_capacities(
199        vec![&original_data, &default_value_data],
200        false,
201        capacity,
202    );
203
204    let mut null_builder = NullBufferBuilder::new(array.len());
205
206    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
207        if array.is_null(row_index) {
208            null_builder.append_null();
209            offsets.push(offsets[row_index]);
210            continue;
211        }
212        null_builder.append_non_null();
213
214        let count = count_array.value(row_index).to_usize().ok_or_else(|| {
215            internal_datafusion_err!("array_resize: failed to convert size to usize")
216        })?;
217        let count = O::usize_as(count);
218        let start = offset_window[0];
219        if start + count > offset_window[1] {
220            let extra_count =
221                (start + count - offset_window[1]).try_into().map_err(|_| {
222                    internal_datafusion_err!(
223                        "array_resize: failed to convert size to i64"
224                    )
225                })?;
226            let end = offset_window[1];
227            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
228            // append default element
229            for _ in 0..extra_count {
230                mutable.extend(1, row_index, row_index + 1);
231            }
232        } else {
233            let end = start + count;
234            mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap());
235        };
236        offsets.push(offsets[row_index] + count);
237    }
238
239    let data = mutable.freeze();
240
241    Ok(Arc::new(GenericListArray::<O>::try_new(
242        Arc::clone(field),
243        OffsetBuffer::<O>::new(offsets.into()),
244        arrow::array::make_array(data),
245        null_builder.finish(),
246    )?))
247}