datafusion_spark/function/array/
shuffle.rs1use arrow::array::{
19 Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData,
20 OffsetSizeTrait,
21};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::DataType;
24use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
25use arrow::datatypes::FieldRef;
26use datafusion_common::cast::{
27 as_fixed_size_list_array, as_large_list_array, as_list_array,
28};
29use datafusion_common::{
30 Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
31};
32use datafusion_expr::{
33 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl,
34 Signature, TypeSignature, Volatility,
35};
36use rand::rng;
37use rand::rngs::StdRng;
38use rand::{Rng, SeedableRng, seq::SliceRandom};
39use std::any::Any;
40use std::sync::Arc;
41
42#[derive(Debug, PartialEq, Eq, Hash)]
43pub struct SparkShuffle {
44 signature: Signature,
45}
46
47impl Default for SparkShuffle {
48 fn default() -> Self {
49 Self::new()
50 }
51}
52
53impl SparkShuffle {
54 pub fn new() -> Self {
55 Self {
56 signature: Signature {
57 type_signature: TypeSignature::OneOf(vec![
58 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
60 arguments: vec![ArrayFunctionArgument::Array],
61 array_coercion: None,
62 }),
63 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
65 arguments: vec![
66 ArrayFunctionArgument::Array,
67 ArrayFunctionArgument::Index,
68 ],
69 array_coercion: None,
70 }),
71 ]),
72 volatility: Volatility::Volatile,
73 parameter_names: None,
74 },
75 }
76 }
77}
78
79impl ScalarUDFImpl for SparkShuffle {
80 fn as_any(&self) -> &dyn Any {
81 self
82 }
83
84 fn name(&self) -> &str {
85 "shuffle"
86 }
87
88 fn signature(&self) -> &Signature {
89 &self.signature
90 }
91
92 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
93 internal_err!("return_field_from_args should be used instead")
94 }
95
96 fn return_field_from_args(
97 &self,
98 args: datafusion_expr::ReturnFieldArgs,
99 ) -> Result<FieldRef> {
100 Ok(Arc::clone(&args.arg_fields[0]))
102 }
103
104 fn invoke_with_args(
105 &self,
106 args: datafusion_expr::ScalarFunctionArgs,
107 ) -> Result<ColumnarValue> {
108 if args.args.is_empty() {
109 return exec_err!("shuffle expects at least 1 argument");
110 }
111 if args.args.len() > 2 {
112 return exec_err!("shuffle expects at most 2 arguments");
113 }
114
115 let seed = if args.args.len() == 2 {
117 extract_seed(&args.args[1])?
118 } else {
119 None
120 };
121
122 let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?;
124 array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array)
125 }
126}
127
128fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
130 match seed_arg {
131 ColumnarValue::Scalar(scalar) => {
132 let seed = match scalar {
133 ScalarValue::Int64(Some(v)) => Some(*v as u64),
134 ScalarValue::Null => None,
135 _ => {
136 return exec_err!(
137 "shuffle seed must be Int64 type, got '{}'",
138 scalar.data_type()
139 );
140 }
141 };
142 Ok(seed)
143 }
144 ColumnarValue::Array(_) => {
145 exec_err!("shuffle seed must be a scalar value, not an array")
146 }
147 }
148}
149
150fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayRef> {
152 let [input_array] = take_function_args("shuffle", arg)?;
153 match &input_array.data_type() {
154 List(field) => {
155 let array = as_list_array(input_array)?;
156 general_array_shuffle::<i32>(array, field, seed)
157 }
158 LargeList(field) => {
159 let array = as_large_list_array(input_array)?;
160 general_array_shuffle::<i64>(array, field, seed)
161 }
162 FixedSizeList(field, _) => {
163 let array = as_fixed_size_list_array(input_array)?;
164 fixed_size_array_shuffle(array, field, seed)
165 }
166 Null => Ok(Arc::clone(input_array)),
167 array_type => exec_err!("shuffle does not support type '{array_type}'."),
168 }
169}
170
171fn general_array_shuffle<O: OffsetSizeTrait>(
172 array: &GenericListArray<O>,
173 field: &FieldRef,
174 seed: Option<u64>,
175) -> Result<ArrayRef> {
176 let values = array.values();
177 let original_data = values.to_data();
178 let capacity = Capacities::Array(original_data.len());
179 let mut offsets = vec![O::usize_as(0)];
180 let mut nulls = vec![];
181 let mut mutable =
182 MutableArrayData::with_capacities(vec![&original_data], false, capacity);
183 let mut rng = if let Some(s) = seed {
184 StdRng::seed_from_u64(s)
185 } else {
186 let seed = rng().random::<u64>();
188 StdRng::seed_from_u64(seed)
189 };
190
191 for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
192 if array.is_null(row_index) {
194 nulls.push(false);
195 offsets.push(offsets[row_index] + O::one());
196 mutable.extend(0, 0, 1);
197 continue;
198 }
199 nulls.push(true);
200 let start = offset_window[0];
201 let end = offset_window[1];
202 let length = (end - start).to_usize().unwrap();
203
204 let mut indices: Vec<usize> =
206 (start.to_usize().unwrap()..end.to_usize().unwrap()).collect();
207 indices.shuffle(&mut rng);
208
209 for &index in &indices {
211 mutable.extend(0, index, index + 1);
212 }
213
214 offsets.push(offsets[row_index] + O::usize_as(length));
215 }
216
217 let data = mutable.freeze();
218 Ok(Arc::new(GenericListArray::<O>::try_new(
219 Arc::clone(field),
220 OffsetBuffer::<O>::new(offsets.into()),
221 arrow::array::make_array(data),
222 Some(nulls.into()),
223 )?))
224}
225
226fn fixed_size_array_shuffle(
227 array: &FixedSizeListArray,
228 field: &FieldRef,
229 seed: Option<u64>,
230) -> Result<ArrayRef> {
231 let values = array.values();
232 let original_data = values.to_data();
233 let capacity = Capacities::Array(original_data.len());
234 let mut nulls = vec![];
235 let mut mutable =
236 MutableArrayData::with_capacities(vec![&original_data], false, capacity);
237 let value_length = array.value_length() as usize;
238 let mut rng = if let Some(s) = seed {
239 StdRng::seed_from_u64(s)
240 } else {
241 let seed = rng().random::<u64>();
243 StdRng::seed_from_u64(seed)
244 };
245
246 for row_index in 0..array.len() {
247 if array.is_null(row_index) {
249 nulls.push(false);
250 mutable.extend(0, 0, value_length);
251 continue;
252 }
253 nulls.push(true);
254
255 let start = row_index * value_length;
256 let end = start + value_length;
257
258 let mut indices: Vec<usize> = (start..end).collect();
260 indices.shuffle(&mut rng);
261
262 for &index in &indices {
264 mutable.extend(0, index, index + 1);
265 }
266 }
267
268 let data = mutable.freeze();
269 Ok(Arc::new(FixedSizeListArray::try_new(
270 Arc::clone(field),
271 array.value_length(),
272 arrow::array::make_array(data),
273 Some(nulls.into()),
274 )?))
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Field;
    use datafusion_expr::ReturnFieldArgs;

    /// Builds a `List<Int32>` argument field named `arr` with the given
    /// nullability.
    fn list_arg_field(nullable: bool) -> FieldRef {
        let item = Arc::new(Field::new("item", DataType::Int32, true));
        Arc::new(Field::new("arr", List(item), nullable))
    }

    #[test]
    fn test_shuffle_nullability() {
        let udf = SparkShuffle::new();

        for nullable in [false, true] {
            let input = list_arg_field(nullable);
            let output = udf
                .return_field_from_args(ReturnFieldArgs {
                    arg_fields: &[Arc::clone(&input)],
                    scalar_arguments: &[None],
                })
                .unwrap();

            // The return field must mirror both the input's nullability and
            // its data type.
            assert_eq!(output.is_nullable(), nullable);
            assert_eq!(output.data_type(), input.data_type());
        }
    }
}