datafusion_functions_nested/
repeat.rs1use crate::utils::make_scalar_function;
21use arrow::array::{
22 Array, ArrayRef, BooleanBufferBuilder, GenericListArray, Int64Array, OffsetSizeTrait,
23 UInt64Array,
24};
25use arrow::buffer::{NullBuffer, OffsetBuffer};
26use arrow::compute;
27use arrow::datatypes::DataType;
28use arrow::datatypes::{
29 DataType::{LargeList, List},
30 Field,
31};
32use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array};
33use datafusion_common::types::{NativeType, logical_int64};
34use datafusion_common::{DataFusionError, Result};
35use datafusion_expr::{
36 ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility,
37};
38use datafusion_expr_common::signature::{Coercion, TypeSignatureClass};
39use datafusion_macros::user_doc;
40use std::any::Any;
41use std::sync::Arc;
42
// Registers `ArrayRepeat` through the crate's `make_udf_expr_and_func!` macro.
// NOTE(review): the macro is defined outside this file; judging by its name and
// arguments it presumably declares the `array_repeat` expression helper and the
// `array_repeat_udf` accessor — confirm against the macro definition.
make_udf_expr_and_func!(
    // UDF implementation type.
    ArrayRepeat,
    // Name of the generated expression function.
    array_repeat,
    // Argument names of the generated expression function.
    element count,
    // Description attached to the generated function.
    "returns an array containing element `count` times.",
    // Name of the generated UDF accessor.
    array_repeat_udf
);
50
51#[user_doc(
52 doc_section(label = "Array Functions"),
53 description = "Returns an array containing element `count` times.",
54 syntax_example = "array_repeat(element, count)",
55 sql_example = r#"```sql
56> select array_repeat(1, 3);
57+---------------------------------+
58| array_repeat(Int64(1),Int64(3)) |
59+---------------------------------+
60| [1, 1, 1] |
61+---------------------------------+
62> select array_repeat([1, 2], 2);
63+------------------------------------+
64| array_repeat(List([1,2]),Int64(2)) |
65+------------------------------------+
66| [[1, 2], [1, 2]] |
67+------------------------------------+
68```"#,
69 argument(
70 name = "element",
71 description = "Element expression. Can be a constant, column, or function, and any combination of array operators."
72 ),
73 argument(
74 name = "count",
75 description = "Value of how many times to repeat the element."
76 )
77)]
78#[derive(Debug, PartialEq, Eq, Hash)]
79pub struct ArrayRepeat {
80 signature: Signature,
81 aliases: Vec<String>,
82}
83
84impl Default for ArrayRepeat {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl ArrayRepeat {
91 pub fn new() -> Self {
92 Self {
93 signature: Signature::coercible(
94 vec![
95 Coercion::new_exact(TypeSignatureClass::Any),
96 Coercion::new_implicit(
97 TypeSignatureClass::Native(logical_int64()),
98 vec![TypeSignatureClass::Integer],
99 NativeType::Int64,
100 ),
101 ],
102 Volatility::Immutable,
103 ),
104 aliases: vec![String::from("list_repeat")],
105 }
106 }
107}
108
109impl ScalarUDFImpl for ArrayRepeat {
110 fn as_any(&self) -> &dyn Any {
111 self
112 }
113
114 fn name(&self) -> &str {
115 "array_repeat"
116 }
117
118 fn signature(&self) -> &Signature {
119 &self.signature
120 }
121
122 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
123 let element_type = &arg_types[0];
124 match element_type {
125 LargeList(_) => Ok(LargeList(Arc::new(Field::new_list_field(
126 element_type.clone(),
127 true,
128 )))),
129 _ => Ok(List(Arc::new(Field::new_list_field(
130 element_type.clone(),
131 true,
132 )))),
133 }
134 }
135
136 fn invoke_with_args(
137 &self,
138 args: datafusion_expr::ScalarFunctionArgs,
139 ) -> Result<ColumnarValue> {
140 make_scalar_function(array_repeat_inner)(&args.args)
141 }
142
143 fn aliases(&self) -> &[String] {
144 &self.aliases
145 }
146
147 fn documentation(&self) -> Option<&Documentation> {
148 self.doc()
149 }
150}
151
152fn array_repeat_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
153 let element = &args[0];
154 let count_array = as_int64_array(&args[1])?;
155
156 match element.data_type() {
157 List(_) => {
158 let list_array = as_list_array(element)?;
159 general_list_repeat::<i32>(list_array, count_array)
160 }
161 LargeList(_) => {
162 let list_array = as_large_list_array(element)?;
163 general_list_repeat::<i64>(list_array, count_array)
164 }
165 _ => general_repeat::<i32>(element, count_array),
166 }
167}
168
169fn general_repeat<O: OffsetSizeTrait>(
182 array: &ArrayRef,
183 count_array: &Int64Array,
184) -> Result<ArrayRef> {
185 let total_repeated_values: usize = (0..count_array.len())
186 .map(|i| get_count_with_validity(count_array, i))
187 .sum();
188
189 let mut take_indices = Vec::with_capacity(total_repeated_values);
190 let mut offsets = Vec::with_capacity(count_array.len() + 1);
191 offsets.push(O::zero());
192 let mut running_offset = 0usize;
193
194 for idx in 0..count_array.len() {
195 let count = get_count_with_validity(count_array, idx);
196 running_offset = running_offset.checked_add(count).ok_or_else(|| {
197 DataFusionError::Execution(
198 "array_repeat: running_offset overflowed usize".to_string(),
199 )
200 })?;
201 let offset = O::from_usize(running_offset).ok_or_else(|| {
202 DataFusionError::Execution(format!(
203 "array_repeat: offset {running_offset} exceeds the maximum value for offset type"
204 ))
205 })?;
206 offsets.push(offset);
207 take_indices.extend(std::iter::repeat_n(idx as u64, count));
208 }
209
210 let repeated_values = compute::take(
212 array.as_ref(),
213 &UInt64Array::from_iter_values(take_indices),
214 None,
215 )?;
216
217 Ok(Arc::new(GenericListArray::<O>::try_new(
219 Arc::new(Field::new_list_field(array.data_type().to_owned(), true)),
220 OffsetBuffer::new(offsets.into()),
221 repeated_values,
222 count_array.nulls().cloned(),
223 )?))
224}
225
226fn general_list_repeat<O: OffsetSizeTrait>(
237 list_array: &GenericListArray<O>,
238 count_array: &Int64Array,
239) -> Result<ArrayRef> {
240 let list_offsets = list_array.value_offsets();
241
242 let mut outer_total = 0usize;
244 let mut inner_total = 0usize;
245 for i in 0..count_array.len() {
246 let count = get_count_with_validity(count_array, i);
247 if count > 0 {
248 outer_total += count;
249 if list_array.is_valid(i) {
250 let len = list_offsets[i + 1].to_usize().unwrap()
251 - list_offsets[i].to_usize().unwrap();
252 inner_total += len * count;
253 }
254 }
255 }
256
257 let mut inner_offsets = Vec::with_capacity(outer_total + 1);
259 let mut take_indices = Vec::with_capacity(inner_total);
260 let mut inner_nulls = BooleanBufferBuilder::new(outer_total);
261 let mut inner_running = 0usize;
262 inner_offsets.push(O::zero());
263
264 for row_idx in 0..count_array.len() {
265 let count = get_count_with_validity(count_array, row_idx);
266 let list_is_valid = list_array.is_valid(row_idx);
267 let start = list_offsets[row_idx].to_usize().unwrap();
268 let end = list_offsets[row_idx + 1].to_usize().unwrap();
269 let row_len = end - start;
270
271 for _ in 0..count {
272 inner_running = inner_running.checked_add(row_len).ok_or_else(|| {
273 DataFusionError::Execution(
274 "array_repeat: inner offset overflowed usize".to_string(),
275 )
276 })?;
277 let offset = O::from_usize(inner_running).ok_or_else(|| {
278 DataFusionError::Execution(format!(
279 "array_repeat: offset {inner_running} exceeds the maximum value for offset type"
280 ))
281 })?;
282 inner_offsets.push(offset);
283 inner_nulls.append(list_is_valid);
284 if list_is_valid {
285 take_indices.extend(start as u64..end as u64);
286 }
287 }
288 }
289
290 let inner_values = compute::take(
292 list_array.values().as_ref(),
293 &UInt64Array::from_iter_values(take_indices),
294 None,
295 )?;
296 let inner_list = GenericListArray::<O>::try_new(
297 Arc::new(Field::new_list_field(list_array.value_type().clone(), true)),
298 OffsetBuffer::new(inner_offsets.into()),
299 inner_values,
300 Some(NullBuffer::new(inner_nulls.finish())),
301 )?;
302
303 Ok(Arc::new(GenericListArray::<O>::try_new(
305 Arc::new(Field::new_list_field(
306 list_array.data_type().to_owned(),
307 true,
308 )),
309 OffsetBuffer::<O>::from_lengths(
310 count_array
311 .iter()
312 .map(|c| c.map(|v| if v > 0 { v as usize } else { 0 }).unwrap_or(0)),
313 ),
314 Arc::new(inner_list),
315 count_array.nulls().cloned(),
316 )?))
317}
318
319#[inline]
322fn get_count_with_validity(count_array: &Int64Array, idx: usize) -> usize {
323 if count_array.is_null(idx) {
324 0
325 } else {
326 let c = count_array.value(idx);
327 if c > 0 { c as usize } else { 0 }
328 }
329}