datafusion_spark/function/array/
shuffle.rs1use arrow::array::{
19 Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData,
20 OffsetSizeTrait,
21};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::DataType;
24use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
25use arrow::datatypes::FieldRef;
26use datafusion_common::cast::{
27 as_fixed_size_list_array, as_large_list_array, as_list_array,
28};
29use datafusion_common::{
30 Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
31};
32use datafusion_expr::{
33 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarFunctionArgs,
34 ScalarUDFImpl, Signature, TypeSignature, Volatility,
35};
36use rand::rng;
37use rand::rngs::StdRng;
38use rand::{Rng, SeedableRng, seq::SliceRandom};
39use std::sync::Arc;
40
41#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkShuffle {
43 signature: Signature,
44}
45
46impl Default for SparkShuffle {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl SparkShuffle {
53 pub fn new() -> Self {
54 Self {
55 signature: Signature {
56 type_signature: TypeSignature::OneOf(vec![
57 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
59 arguments: vec![ArrayFunctionArgument::Array],
60 array_coercion: None,
61 }),
62 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
64 arguments: vec![
65 ArrayFunctionArgument::Array,
66 ArrayFunctionArgument::Index,
67 ],
68 array_coercion: None,
69 }),
70 ]),
71 volatility: Volatility::Volatile,
72 parameter_names: None,
73 },
74 }
75 }
76}
77
78impl ScalarUDFImpl for SparkShuffle {
79 fn name(&self) -> &str {
80 "shuffle"
81 }
82
83 fn signature(&self) -> &Signature {
84 &self.signature
85 }
86
87 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
88 internal_err!("return_field_from_args should be used instead")
89 }
90
91 fn return_field_from_args(
92 &self,
93 args: datafusion_expr::ReturnFieldArgs,
94 ) -> Result<FieldRef> {
95 Ok(Arc::clone(&args.arg_fields[0]))
97 }
98
99 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100 if args.args.is_empty() || args.args.len() > 2 {
101 return exec_err!("shuffle expects 1 or 2 argument(s)");
102 }
103
104 let seed = if args.args.len() == 2 {
106 extract_seed(&args.args[1])?
107 } else {
108 None
109 };
110
111 let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?;
113 array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array)
114 }
115}
116
117fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
119 match seed_arg {
120 ColumnarValue::Scalar(scalar) => {
121 let seed = match scalar {
122 ScalarValue::Int64(Some(v)) => Some(*v as u64),
123 ScalarValue::Null | ScalarValue::Int64(None) => None,
124 _ => {
125 return exec_err!(
126 "shuffle seed must be Int64 type but got '{}'",
127 scalar.data_type()
128 );
129 }
130 };
131 Ok(seed)
132 }
133 ColumnarValue::Array(_) => {
134 exec_err!("shuffle seed must be a scalar value, not an array")
135 }
136 }
137}
138
139fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayRef> {
141 let [input_array] = take_function_args("shuffle", arg)?;
142 match &input_array.data_type() {
143 List(field) => {
144 let array = as_list_array(input_array)?;
145 general_array_shuffle::<i32>(array, field, seed)
146 }
147 LargeList(field) => {
148 let array = as_large_list_array(input_array)?;
149 general_array_shuffle::<i64>(array, field, seed)
150 }
151 FixedSizeList(field, _) => {
152 let array = as_fixed_size_list_array(input_array)?;
153 fixed_size_array_shuffle(array, field, seed)
154 }
155 Null => Ok(Arc::clone(input_array)),
156 array_type => exec_err!(
157 "shuffle does not support type '{array_type}'; \
158 expected types: List, LargeList, FixedSizeList or Null."
159 ),
160 }
161}
162
163fn general_array_shuffle<O: OffsetSizeTrait>(
164 array: &GenericListArray<O>,
165 field: &FieldRef,
166 seed: Option<u64>,
167) -> Result<ArrayRef> {
168 let values = array.values();
169 let original_data = values.to_data();
170 let capacity = Capacities::Array(original_data.len());
171 let mut offsets = vec![O::usize_as(0)];
172 let mut nulls = vec![];
173 let mut mutable =
174 MutableArrayData::with_capacities(vec![&original_data], false, capacity);
175 let mut rng = if let Some(s) = seed {
176 StdRng::seed_from_u64(s)
177 } else {
178 let seed = rng().random::<u64>();
180 StdRng::seed_from_u64(seed)
181 };
182
183 for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
184 if array.is_null(row_index) {
186 nulls.push(false);
187 offsets.push(offsets[row_index] + O::one());
188 mutable.extend(0, 0, 1);
189 continue;
190 }
191 nulls.push(true);
192 let start = offset_window[0];
193 let end = offset_window[1];
194 let length = (end - start).to_usize().unwrap();
195
196 let mut indices: Vec<usize> =
198 (start.to_usize().unwrap()..end.to_usize().unwrap()).collect();
199 indices.shuffle(&mut rng);
200
201 for &index in &indices {
203 mutable.extend(0, index, index + 1);
204 }
205
206 offsets.push(offsets[row_index] + O::usize_as(length));
207 }
208
209 let data = mutable.freeze();
210 Ok(Arc::new(GenericListArray::<O>::try_new(
211 Arc::clone(field),
212 OffsetBuffer::<O>::new(offsets.into()),
213 arrow::array::make_array(data),
214 Some(nulls.into()),
215 )?))
216}
217
218fn fixed_size_array_shuffle(
219 array: &FixedSizeListArray,
220 field: &FieldRef,
221 seed: Option<u64>,
222) -> Result<ArrayRef> {
223 let values = array.values();
224 let original_data = values.to_data();
225 let capacity = Capacities::Array(original_data.len());
226 let mut nulls = vec![];
227 let mut mutable =
228 MutableArrayData::with_capacities(vec![&original_data], false, capacity);
229 let value_length = array.value_length() as usize;
230 let mut rng = if let Some(s) = seed {
231 StdRng::seed_from_u64(s)
232 } else {
233 let seed = rng().random::<u64>();
235 StdRng::seed_from_u64(seed)
236 };
237
238 for row_index in 0..array.len() {
239 if array.is_null(row_index) {
241 nulls.push(false);
242 mutable.extend(0, 0, value_length);
243 continue;
244 }
245 nulls.push(true);
246
247 let start = row_index * value_length;
248 let end = start + value_length;
249
250 let mut indices: Vec<usize> = (start..end).collect();
252 indices.shuffle(&mut rng);
253
254 for &index in &indices {
256 mutable.extend(0, index, index + 1);
257 }
258 }
259
260 let data = mutable.freeze();
261 Ok(Arc::new(FixedSizeListArray::try_new(
262 Arc::clone(field),
263 array.value_length(),
264 arrow::array::make_array(data),
265 Some(nulls.into()),
266 )?))
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Field;
    use datafusion_expr::ReturnFieldArgs;

    /// The returned field must keep the nullability and data type of the
    /// input list field, for both non-nullable and nullable inputs.
    #[test]
    fn test_shuffle_nullability() {
        let shuffle = SparkShuffle::new();

        for nullable in [false, true] {
            let item = Arc::new(Field::new("item", DataType::Int32, true));
            let input = Arc::new(Field::new("arr", List(item), nullable));

            let output = shuffle
                .return_field_from_args(ReturnFieldArgs {
                    arg_fields: &[Arc::clone(&input)],
                    scalar_arguments: &[None],
                })
                .unwrap();

            assert_eq!(output.is_nullable(), nullable);
            assert_eq!(output.data_type(), input.data_type());
        }
    }
}