datafusion_comet_spark_expr/array_funcs/
array_repeat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData,
20    NullBufferBuilder, OffsetSizeTrait, UInt64Array,
21};
22use arrow::buffer::OffsetBuffer;
23use arrow::compute;
24use arrow::compute::cast;
25use arrow::datatypes::DataType::{LargeList, List};
26use arrow::datatypes::{DataType, Field};
27use datafusion::common::cast::{as_large_list_array, as_list_array, as_uint64_array};
28use datafusion::common::{exec_err, DataFusionError, ScalarValue};
29use datafusion::logical_expr::ColumnarValue;
30use std::sync::Arc;
31
32pub fn make_scalar_function<F>(
33    inner: F,
34) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
35where
36    F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
37{
38    move |args: &[ColumnarValue]| {
39        // first, identify if any of the arguments is an Array. If yes, store its `len`,
40        // as any scalar will need to be converted to an array of len `len`.
41        let len = args
42            .iter()
43            .fold(Option::<usize>::None, |acc, arg| match arg {
44                ColumnarValue::Scalar(_) => acc,
45                ColumnarValue::Array(a) => Some(a.len()),
46            });
47
48        let is_scalar = len.is_none();
49
50        let args = ColumnarValue::values_to_arrays(args)?;
51
52        let result = (inner)(&args);
53
54        if is_scalar {
55            // If all inputs are scalar, keeps output as scalar
56            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
57            result.map(ColumnarValue::Scalar)
58        } else {
59            result.map(ColumnarValue::Array)
60        }
61    }
62}
63
64pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
65    make_scalar_function(spark_array_repeat_inner)(args)
66}
67
68/// Array_repeat SQL function
69fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> {
70    let element = &args[0];
71    let count_array = &args[1];
72
73    let count_array = match count_array.data_type() {
74        DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
75        DataType::UInt64 => count_array,
76        _ => return exec_err!("count must be an integer type"),
77    };
78
79    let count_array = as_uint64_array(count_array)?;
80
81    match element.data_type() {
82        List(_) => {
83            let list_array = as_list_array(element)?;
84            general_list_repeat::<i32>(list_array, count_array)
85        }
86        LargeList(_) => {
87            let list_array = as_large_list_array(element)?;
88            general_list_repeat::<i64>(list_array, count_array)
89        }
90        _ => general_repeat::<i32>(element, count_array),
91    }
92}
93
94/// For each element of `array[i]` repeat `count_array[i]` times.
95///
96/// Assumption for the input:
97///     1. `count[i] >= 0`
98///     2. `array.len() == count_array.len()`
99///
100/// For example,
101/// ```text
102/// array_repeat(
103///     [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
104/// )
105/// ```
106fn general_repeat<O: OffsetSizeTrait>(
107    array: &ArrayRef,
108    count_array: &UInt64Array,
109) -> datafusion::common::Result<ArrayRef> {
110    let data_type = array.data_type();
111    let mut new_values = vec![];
112
113    let count_vec = count_array
114        .values()
115        .to_vec()
116        .iter()
117        .map(|x| *x as usize)
118        .collect::<Vec<_>>();
119
120    let mut nulls = NullBufferBuilder::new(count_array.len());
121
122    for (row_index, &count) in count_vec.iter().enumerate() {
123        nulls.append(!count_array.is_null(row_index));
124        let repeated_array = if array.is_null(row_index) {
125            new_null_array(data_type, count)
126        } else {
127            let original_data = array.to_data();
128            let capacity = Capacities::Array(count);
129            let mut mutable =
130                MutableArrayData::with_capacities(vec![&original_data], false, capacity);
131
132            for _ in 0..count {
133                mutable.extend(0, row_index, row_index + 1);
134            }
135
136            let data = mutable.freeze();
137            arrow::array::make_array(data)
138        };
139        new_values.push(repeated_array);
140    }
141
142    let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
143    let values = compute::concat(&new_values)?;
144
145    Ok(Arc::new(GenericListArray::<O>::try_new(
146        Arc::new(Field::new_list_field(data_type.to_owned(), true)),
147        OffsetBuffer::from_lengths(count_vec),
148        values,
149        nulls.finish(),
150    )?))
151}
152
153/// Handle List version of `general_repeat`
154///
155/// For each element of `list_array[i]` repeat `count_array[i]` times.
156///
157/// For example,
158/// ```text
159/// array_repeat(
160///     [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]]
161/// )
162/// ```
163fn general_list_repeat<O: OffsetSizeTrait>(
164    list_array: &GenericListArray<O>,
165    count_array: &UInt64Array,
166) -> datafusion::common::Result<ArrayRef> {
167    let data_type = list_array.data_type();
168    let value_type = list_array.value_type();
169    let mut new_values = vec![];
170
171    let count_vec = count_array
172        .values()
173        .to_vec()
174        .iter()
175        .map(|x| *x as usize)
176        .collect::<Vec<_>>();
177
178    for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) {
179        let list_arr = match list_array_row {
180            Some(list_array_row) => {
181                let original_data = list_array_row.to_data();
182                let capacity = Capacities::Array(original_data.len() * count);
183                let mut mutable =
184                    MutableArrayData::with_capacities(vec![&original_data], false, capacity);
185
186                for _ in 0..count {
187                    mutable.extend(0, 0, original_data.len());
188                }
189
190                let data = mutable.freeze();
191                let repeated_array = arrow::array::make_array(data);
192
193                let list_arr = GenericListArray::<O>::try_new(
194                    Arc::new(Field::new_list_field(value_type.clone(), true)),
195                    OffsetBuffer::<O>::from_lengths(vec![original_data.len(); count]),
196                    repeated_array,
197                    None,
198                )?;
199                Arc::new(list_arr) as ArrayRef
200            }
201            None => new_null_array(data_type, count),
202        };
203        new_values.push(list_arr);
204    }
205
206    let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
207    let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
208    let values = compute::concat(&new_values)?;
209
210    Ok(Arc::new(ListArray::try_new(
211        Arc::new(Field::new_list_field(data_type.to_owned(), true)),
212        OffsetBuffer::<i32>::from_lengths(lengths),
213        values,
214        None,
215    )?))
216}