Skip to main content

datafusion_functions_nested/
repeat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! [`ScalarUDFImpl`] definitions for array_repeat function.
19
20use crate::utils::make_scalar_function;
21use arrow::array::{
22    Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, OffsetSizeTrait,
23    UInt64Array,
24};
25use arrow::buffer::{NullBuffer, OffsetBuffer};
26use arrow::compute;
27use arrow::datatypes::DataType;
28use arrow::datatypes::{
29    DataType::{LargeList, List},
30    Field,
31};
32use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
33use datafusion_common::types::{NativeType, logical_int64};
34use datafusion_common::{DataFusionError, Result};
35use datafusion_expr::{
36    ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
37};
38use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
39use datafusion_macros::user_doc;
40use std::any::Any;
41use std::sync::Arc;
42
// Registers the helper items for the `ArrayRepeat` UDF declared below.
// NOTE(review): `make_udf_expr_and_func!` is defined outside this file. Judging
// only by the inline argument labels, it presumably emits an
// `array_repeat(element, count)` expression builder plus an `array_repeat_udf`
// accessor for the UDF — confirm against the macro definition.
make_udf_expr_and_func!(
    ArrayRepeat,
    array_repeat,
    element count, // arg name
    "returns an array containing element `count` times.", // doc
    array_repeat_udf // internal function name
);
50
// User-facing documentation for `array_repeat`; `documentation()` in the
// `ScalarUDFImpl` impl returns whatever `self.doc()` provides.
// NOTE(review): `doc()` is presumably generated by this `user_doc` attribute —
// confirm against the macro.
#[user_doc(
    doc_section(label = "Array Functions"),
    description = "Returns an array containing element `count` times.",
    syntax_example = "array_repeat(element, count)",
    sql_example = r#"```sql
> select array_repeat(1, 3);
+---------------------------------+
| array_repeat(Int64(1),Int64(3)) |
+---------------------------------+
| [1, 1, 1]                       |
+---------------------------------+
> select array_repeat([1, 2], 2);
+------------------------------------+
| array_repeat(List([1,2]),Int64(2)) |
+------------------------------------+
| [[1, 2], [1, 2]]                   |
+------------------------------------+
```"#,
    argument(
        name = "element",
        description = "Element expression. Can be a constant, column, or function, and any combination of array operators."
    ),
    argument(
        name = "count",
        description = "Value of how many times to repeat the element."
    )
)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct ArrayRepeat {
    // Accepted argument types; built once in `ArrayRepeat::new` and handed out
    // by `ScalarUDFImpl::signature`.
    signature: Signature,
    // Alternative SQL names for the function; `new` sets this to `list_repeat`.
    aliases: Vec<String>,
}
83
84impl Default for ArrayRepeat {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl ArrayRepeat {
91    pub fn new() -> Self {
92        Self {
93            signature: Signature::coercible(
94                vec![
95                    Coercion::new_exact(TypeSignatureClass::Any),
96                    Coercion::new_implicit(
97                        TypeSignatureClass::Native(logical_int64()),
98                        vec![TypeSignatureClass::Integer],
99                        NativeType::Int64,
100                    ),
101                ],
102                Volatility::Immutable,
103            ),
104            aliases: vec![String::from("list_repeat")],
105        }
106    }
107}
108
109impl ScalarUDFImpl for ArrayRepeat {
110    fn as_any(&self) -> &dyn Any {
111        self
112    }
113
114    fn name(&self) -> &str {
115        "array_repeat"
116    }
117
118    fn signature(&self) -> &Signature {
119        &self.signature
120    }
121
122    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
123        let element_type = &arg_types[0];
124        match element_type {
125            LargeList(_) => Ok(LargeList(Arc::new(Field::new_list_field(
126                element_type.clone(),
127                true,
128            )))),
129            _ => Ok(List(Arc::new(Field::new_list_field(
130                element_type.clone(),
131                true,
132            )))),
133        }
134    }
135
136    fn invoke_with_args(
137        &self,
138        args: datafusion_expr::ScalarFunctionArgs,
139    ) -> Result<ColumnarValue> {
140        make_scalar_function(array_repeat_inner)(&args.args)
141    }
142
143    fn aliases(&self) -> &[String] {
144        &self.aliases
145    }
146
147    fn documentation(&self) -> Option<&Documentation> {
148        self.doc()
149    }
150}
151
152fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
153    let element = &args[0];
154    let count_array = as_int64_array(&args[1])?;
155
156    match element.data_type() {
157        List(_) => {
158            let list_array = as_list_array(element)?;
159            general_list_repeat::<i32>(list_array, count_array)
160        }
161        LargeList(_) => {
162            let list_array = as_large_list_array(element)?;
163            general_list_repeat::<i64>(list_array, count_array)
164        }
165        _ => general_repeat::<i32>(element, count_array),
166    }
167}
168
169/// For each element of `array[i]` repeat `count_array[i]` times.
170///
171/// Assumption for the input:
172///     1. `count[i] >= 0`
173///     2. `array.len() == count_array.len()`
174///
175/// For example,
176/// ```text
177/// array_repeat(
178///     [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
179/// )
180/// ```
181fn general_repeat<O: OffsetSizeTrait>(
182    array: &ArrayRef,
183    count_array: &Int64Array,
184) -> Result<ArrayRef> {
185    let total_repeated_values: usize = (0..count_array.len())
186        .map(|i| get_count_with_validity(count_array, i))
187        .sum();
188
189    let mut take_indices = Vec::with_capacity(total_repeated_values);
190    let mut offsets = Vec::with_capacity(count_array.len() + 1);
191    offsets.push(O::zero());
192    let mut running_offset = 0usize;
193
194    for idx in 0..count_array.len() {
195        let count = get_count_with_validity(count_array, idx);
196        running_offset = running_offset.checked_add(count).ok_or_else(|| {
197            DataFusionError::Execution(
198                "array_repeat: running_offset overflowed usize".to_string(),
199            )
200        })?;
201        let offset = O::from_usize(running_offset).ok_or_else(|| {
202            DataFusionError::Execution(format!(
203                "array_repeat: offset {running_offset} exceeds the maximum value for offset type"
204            ))
205        })?;
206        offsets.push(offset);
207        take_indices.extend(std::iter::repeat_n(idx as u64, count));
208    }
209
210    // Build the flattened values
211    let repeated_values = compute::take(
212        array.as_ref(),
213        &UInt64Array::from_iter_values(take_indices),
214        None,
215    )?;
216
217    // Construct final ListArray
218    Ok(Arc::new(GenericListArray::<O>::try_new(
219        Arc::new(Field::new_list_field(array.data_type().to_owned(), true)),
220        OffsetBuffer::new(offsets.into()),
221        repeated_values,
222        count_array.nulls().cloned(),
223    )?))
224}
225
226/// Handle List version of `general_repeat`
227///
228/// For each element of `list_array[i]` repeat `count_array[i]` times.
229///
230/// For example,
231/// ```text
232/// array_repeat(
233///     [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]]
234/// )
235/// ```
236fn general_list_repeat<O: OffsetSizeTrait>(
237    list_array: &GenericListArray<O>,
238    count_array: &Int64Array,
239) -> Result<ArrayRef> {
240    let list_offsets = list_array.value_offsets();
241
242    // calculate capacities for pre-allocation
243    let mut outer_total = 0usize;
244    let mut inner_total = 0usize;
245    for i in 0..count_array.len() {
246        let count = get_count_with_validity(count_array, i);
247        if count > 0 {
248            outer_total += count;
249            if list_array.is_valid(i) {
250                let len = list_offsets[i + 1].to_usize().unwrap()
251                    - list_offsets[i].to_usize().unwrap();
252                inner_total += len * count;
253            }
254        }
255    }
256
257    // Build inner structures
258    let mut inner_offsets = Vec::with_capacity(outer_total + 1);
259    let mut take_indices = Vec::with_capacity(inner_total);
260    let mut inner_nulls = BooleanBufferBuilder::new(outer_total);
261    let mut inner_running = 0usize;
262    inner_offsets.push(O::zero());
263
264    for row_idx in 0..count_array.len() {
265        let count = get_count_with_validity(count_array, row_idx);
266        let list_is_valid = list_array.is_valid(row_idx);
267        let start = list_offsets[row_idx].to_usize().unwrap();
268        let end = list_offsets[row_idx + 1].to_usize().unwrap();
269        let row_len = end - start;
270
271        for _ in 0..count {
272            inner_running = inner_running.checked_add(row_len).ok_or_else(|| {
273                DataFusionError::Execution(
274                    "array_repeat: inner offset overflowed usize".to_string(),
275                )
276            })?;
277            let offset = O::from_usize(inner_running).ok_or_else(|| {
278                DataFusionError::Execution(format!(
279                    "array_repeat: offset {inner_running} exceeds the maximum value for offset type"
280                ))
281            })?;
282            inner_offsets.push(offset);
283            inner_nulls.append(list_is_valid);
284            if list_is_valid {
285                take_indices.extend(start as u64..end as u64);
286            }
287        }
288    }
289
290    // Build inner ListArray
291    let inner_values = compute::take(
292        list_array.values().as_ref(),
293        &UInt64Array::from_iter_values(take_indices),
294        None,
295    )?;
296    let inner_list = GenericListArray::<O>::try_new(
297        Arc::new(Field::new_list_field(list_array.value_type().clone(), true)),
298        OffsetBuffer::new(inner_offsets.into()),
299        inner_values,
300        Some(NullBuffer::new(inner_nulls.finish())),
301    )?;
302
303    // Build outer ListArray
304    Ok(Arc::new(GenericListArray::<O>::try_new(
305        Arc::new(Field::new_list_field(
306            list_array.data_type().to_owned(),
307            true,
308        )),
309        OffsetBuffer::<O>::from_lengths(
310            count_array
311                .iter()
312                .map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)),
313        ),
314        Arc::new(inner_list),
315        count_array.nulls().cloned(),
316    )?))
317}
318
319/// Helper function to get count from count_array at given index
320/// Return 0 for null values or non-positive count.
321#[inline]
322fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize {
323    if count_array.is_null(idx) {
324        0
325    } else {
326        let c = count_array.value(idx);
327        if c > 0 { c as usize } else { 0 }
328    }
329}