datafusion_spark/function/array/
shuffle.rs1use arrow::array::{
19 Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData,
20 OffsetSizeTrait,
21};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::DataType;
24use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
25use arrow::datatypes::FieldRef;
26use datafusion_common::cast::{
27 as_fixed_size_list_array, as_large_list_array, as_list_array,
28};
29use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue};
30use datafusion_expr::{
31 ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl,
32 Signature, TypeSignature, Volatility,
33};
34use rand::rng;
35use rand::rngs::StdRng;
36use rand::{seq::SliceRandom, Rng, SeedableRng};
37use std::any::Any;
38use std::sync::Arc;
39
/// Scalar UDF implementing the `shuffle` array function.
///
/// The only state carried by the UDF is its [`Signature`], which describes the
/// accepted argument lists and marks the function as volatile.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkShuffle {
    /// Accepted argument shapes and volatility of the function.
    signature: Signature,
}
44
45impl Default for SparkShuffle {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl SparkShuffle {
52 pub fn new() -> Self {
53 Self {
54 signature: Signature {
55 type_signature: TypeSignature::OneOf(vec![
56 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
58 arguments: vec![ArrayFunctionArgument::Array],
59 array_coercion: None,
60 }),
61 TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
63 arguments: vec![
64 ArrayFunctionArgument::Array,
65 ArrayFunctionArgument::Index,
66 ],
67 array_coercion: None,
68 }),
69 ]),
70 volatility: Volatility::Volatile,
71 parameter_names: None,
72 },
73 }
74 }
75}
76
77impl ScalarUDFImpl for SparkShuffle {
78 fn as_any(&self) -> &dyn Any {
79 self
80 }
81
82 fn name(&self) -> &str {
83 "shuffle"
84 }
85
86 fn signature(&self) -> &Signature {
87 &self.signature
88 }
89
90 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
91 Ok(arg_types[0].clone())
92 }
93
94 fn invoke_with_args(
95 &self,
96 args: datafusion_expr::ScalarFunctionArgs,
97 ) -> Result<ColumnarValue> {
98 if args.args.is_empty() {
99 return exec_err!("shuffle expects at least 1 argument");
100 }
101 if args.args.len() > 2 {
102 return exec_err!("shuffle expects at most 2 arguments");
103 }
104
105 let seed = if args.args.len() == 2 {
107 extract_seed(&args.args[1])?
108 } else {
109 None
110 };
111
112 let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?;
114 array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array)
115 }
116}
117
118fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
120 match seed_arg {
121 ColumnarValue::Scalar(scalar) => {
122 let seed = match scalar {
123 ScalarValue::Int64(Some(v)) => Some(*v as u64),
124 ScalarValue::Null => None,
125 _ => {
126 return exec_err!(
127 "shuffle seed must be Int64 type, got '{}'",
128 scalar.data_type()
129 );
130 }
131 };
132 Ok(seed)
133 }
134 ColumnarValue::Array(_) => {
135 exec_err!("shuffle seed must be a scalar value, not an array")
136 }
137 }
138}
139
140fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayRef> {
142 let [input_array] = take_function_args("shuffle", arg)?;
143 match &input_array.data_type() {
144 List(field) => {
145 let array = as_list_array(input_array)?;
146 general_array_shuffle::<i32>(array, field, seed)
147 }
148 LargeList(field) => {
149 let array = as_large_list_array(input_array)?;
150 general_array_shuffle::<i64>(array, field, seed)
151 }
152 FixedSizeList(field, _) => {
153 let array = as_fixed_size_list_array(input_array)?;
154 fixed_size_array_shuffle(array, field, seed)
155 }
156 Null => Ok(Arc::clone(input_array)),
157 array_type => exec_err!("shuffle does not support type '{array_type}'."),
158 }
159}
160
161fn general_array_shuffle<O: OffsetSizeTrait>(
162 array: &GenericListArray<O>,
163 field: &FieldRef,
164 seed: Option<u64>,
165) -> Result<ArrayRef> {
166 let values = array.values();
167 let original_data = values.to_data();
168 let capacity = Capacities::Array(original_data.len());
169 let mut offsets = vec![O::usize_as(0)];
170 let mut nulls = vec![];
171 let mut mutable =
172 MutableArrayData::with_capacities(vec![&original_data], false, capacity);
173 let mut rng = if let Some(s) = seed {
174 StdRng::seed_from_u64(s)
175 } else {
176 let seed = rng().random::<u64>();
178 StdRng::seed_from_u64(seed)
179 };
180
181 for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
182 if array.is_null(row_index) {
184 nulls.push(false);
185 offsets.push(offsets[row_index] + O::one());
186 mutable.extend(0, 0, 1);
187 continue;
188 }
189 nulls.push(true);
190 let start = offset_window[0];
191 let end = offset_window[1];
192 let length = (end - start).to_usize().unwrap();
193
194 let mut indices: Vec<usize> =
196 (start.to_usize().unwrap()..end.to_usize().unwrap()).collect();
197 indices.shuffle(&mut rng);
198
199 for &index in &indices {
201 mutable.extend(0, index, index + 1);
202 }
203
204 offsets.push(offsets[row_index] + O::usize_as(length));
205 }
206
207 let data = mutable.freeze();
208 Ok(Arc::new(GenericListArray::<O>::try_new(
209 Arc::clone(field),
210 OffsetBuffer::<O>::new(offsets.into()),
211 arrow::array::make_array(data),
212 Some(nulls.into()),
213 )?))
214}
215
216fn fixed_size_array_shuffle(
217 array: &FixedSizeListArray,
218 field: &FieldRef,
219 seed: Option<u64>,
220) -> Result<ArrayRef> {
221 let values = array.values();
222 let original_data = values.to_data();
223 let capacity = Capacities::Array(original_data.len());
224 let mut nulls = vec![];
225 let mut mutable =
226 MutableArrayData::with_capacities(vec![&original_data], false, capacity);
227 let value_length = array.value_length() as usize;
228 let mut rng = if let Some(s) = seed {
229 StdRng::seed_from_u64(s)
230 } else {
231 let seed = rng().random::<u64>();
233 StdRng::seed_from_u64(seed)
234 };
235
236 for row_index in 0..array.len() {
237 if array.is_null(row_index) {
239 nulls.push(false);
240 mutable.extend(0, 0, value_length);
241 continue;
242 }
243 nulls.push(true);
244
245 let start = row_index * value_length;
246 let end = start + value_length;
247
248 let mut indices: Vec<usize> = (start..end).collect();
250 indices.shuffle(&mut rng);
251
252 for &index in &indices {
254 mutable.extend(0, index, index + 1);
255 }
256 }
257
258 let data = mutable.freeze();
259 Ok(Arc::new(FixedSizeListArray::try_new(
260 Arc::clone(field),
261 array.value_length(),
262 arrow::array::make_array(data),
263 Some(nulls.into()),
264 )?))
265}