datafusion_comet_spark_expr/array_funcs/array_repeat.rs

use arrow::array::{
    new_null_array, Array, ArrayRef, Capacities, GenericListArray, ListArray, MutableArrayData,
    NullBufferBuilder, OffsetSizeTrait, UInt64Array,
};
use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::compute::cast;
use arrow::datatypes::DataType::{LargeList, List};
use arrow::datatypes::{DataType, Field};
use datafusion::common::cast::{as_large_list_array, as_list_array, as_uint64_array};
use datafusion::common::{exec_err, DataFusionError, ScalarValue};
use datafusion::logical_expr::ColumnarValue;
use std::sync::Arc;

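/// Wraps an array-based kernel so it can be used as a `ColumnarValue` function:
/// scalar arguments are expanded to arrays before calling `inner`, and the result
/// is converted back to a scalar when every input was a scalar.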
pub fn make_scalar_function<F>(
    inner: F,
) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue, DataFusionError>
where
    F: Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError>,
{
    move |args: &[ColumnarValue]| {
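        // If any argument is an Array, remember its length; otherwise all inputs
        // are scalars and the output should stay a scalar as well.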
        let len = args
            .iter()
            .fold(Option::<usize>::None, |acc, arg| match arg {
                ColumnarValue::Scalar(_) => acc,
                ColumnarValue::Array(a) => Some(a.len()),
            });

        let is_scalar = len.is_none();

        let args = ColumnarValue::values_to_arrays(args)?;

        let result = (inner)(&args);

        if is_scalar {
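            // Every input was a scalar, so collapse the single-row result array
            // back into a ScalarValue.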
            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
            result.map(ColumnarValue::Scalar)
        } else {
            result.map(ColumnarValue::Array)
        }
    }
}

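/// Spark-compatible `array_repeat(element, count)`: for each input row, returns an
/// array containing `element` repeated `count` times.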
pub fn spark_array_repeat(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    make_scalar_function(spark_array_repeat_inner)(args)
}

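/// Array-level implementation: normalizes the `count` column to `UInt64` and
/// dispatches on the element type (list, large list, or any other type).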
fn spark_array_repeat_inner(args: &[ArrayRef]) -> datafusion::common::Result<ArrayRef> {
    let element = &args[0];
    let count_array = &args[1];

    let count_array = match count_array.data_type() {
        DataType::Int64 => &cast(count_array, &DataType::UInt64)?,
        DataType::UInt64 => count_array,
        _ => return exec_err!("count must be an integer type"),
    };

    let count_array = as_uint64_array(count_array)?;

    match element.data_type() {
        List(_) => {
            let list_array = as_list_array(element)?;
            general_list_repeat::<i32>(list_array, count_array)
        }
        LargeList(_) => {
            let list_array = as_large_list_array(element)?;
            general_list_repeat::<i64>(list_array, count_array)
        }
        _ => general_repeat::<i32>(element, count_array),
    }
}

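/// For each row `i`, repeats `array[i]` `count_array[i]` times and wraps the result
/// in a list; e.g. elements `[1, 2, 3]` with counts `[2, 0, 1]` produce
/// `[[1, 1], [], [3]]`.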
fn general_repeat<O: OffsetSizeTrait>(
    array: &ArrayRef,
    count_array: &UInt64Array,
) -> datafusion::common::Result<ArrayRef> {
    let data_type = array.data_type();
    let mut new_values = vec![];

    let count_vec = count_array
        .values()
        .to_vec()
        .iter()
        .map(|x| *x as usize)
        .collect::<Vec<_>>();

    let mut nulls = NullBufferBuilder::new(count_array.len());

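    // For each row, copy the element `count` times into a fresh child array; null
    // elements become `count` nulls, and a null count marks the output row as null.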
    for (row_index, &count) in count_vec.iter().enumerate() {
        nulls.append(!count_array.is_null(row_index));
        let repeated_array = if array.is_null(row_index) {
            new_null_array(data_type, count)
        } else {
            let original_data = array.to_data();
            let capacity = Capacities::Array(count);
            let mut mutable =
                MutableArrayData::with_capacities(vec![&original_data], false, capacity);

            for _ in 0..count {
                mutable.extend(0, row_index, row_index + 1);
            }

            let data = mutable.freeze();
            arrow::array::make_array(data)
        };
        new_values.push(repeated_array);
    }

    let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
    let values = compute::concat(&new_values)?;

    Ok(Arc::new(GenericListArray::<O>::try_new(
        Arc::new(Field::new_list_field(data_type.to_owned(), true)),
        OffsetBuffer::from_lengths(count_vec),
        values,
        nulls.finish(),
    )?))
}

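/// List version of `general_repeat`: each inner list `list_array[i]` is repeated
/// `count_array[i]` times; e.g. `[[1, 2, 3], [4, 5], [6]]` with counts `[2, 0, 1]`
/// produces `[[[1, 2, 3], [1, 2, 3]], [], [[6]]]`.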
fn general_list_repeat<O: OffsetSizeTrait>(
    list_array: &GenericListArray<O>,
    count_array: &UInt64Array,
) -> datafusion::common::Result<ArrayRef> {
    let data_type = list_array.data_type();
    let value_type = list_array.value_type();
    let mut new_values = vec![];

    let count_vec = count_array
        .values()
        .to_vec()
        .iter()
        .map(|x| *x as usize)
        .collect::<Vec<_>>();

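    // For each row, copy the whole inner list `count` times into a temporary list
    // array; a null input row contributes `count` null list entries.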
    for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) {
        let list_arr = match list_array_row {
            Some(list_array_row) => {
                let original_data = list_array_row.to_data();
                let capacity = Capacities::Array(original_data.len() * count);
                let mut mutable =
                    MutableArrayData::with_capacities(vec![&original_data], false, capacity);

                for _ in 0..count {
                    mutable.extend(0, 0, original_data.len());
                }

                let data = mutable.freeze();
                let repeated_array = arrow::array::make_array(data);

                let list_arr = GenericListArray::<O>::try_new(
                    Arc::new(Field::new_list_field(value_type.clone(), true)),
                    OffsetBuffer::<O>::from_lengths(vec![original_data.len(); count]),
                    repeated_array,
                    None,
                )?;
                Arc::new(list_arr) as ArrayRef
            }
            None => new_null_array(data_type, count),
        };
        new_values.push(list_arr);
    }

    let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
    let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
    let values = compute::concat(&new_values)?;

    Ok(Arc::new(ListArray::try_new(
        Arc::new(Field::new_list_field(data_type.to_owned(), true)),
        OffsetBuffer::<i32>::from_lengths(lengths),
        values,
        None,
    )?))
}
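
// A minimal usage sketch (not part of the original file): exercising `general_repeat`
// on an Int32 element column. The test name and expected values are illustrative
// assumptions, not taken from the crate's test suite.
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;

    #[test]
    fn repeat_primitive_elements() -> datafusion::common::Result<()> {
        // Elements [1, 2, 3] repeated [2, 0, 1] times should yield [[1, 1], [], [3]].
        let elements: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
        let counts = UInt64Array::from(vec![2u64, 0, 1]);

        let result = general_repeat::<i32>(&elements, &counts)?;
        let list = as_list_array(&result)?;

        assert_eq!(list.len(), 3);
        assert_eq!(list.value_length(0), 2);
        assert_eq!(list.value_length(1), 0);
        assert_eq!(list.value_length(2), 1);
        Ok(())
    }
}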