Skip to main content

datafusion_spark/function/array/
shuffle.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData,
20    OffsetSizeTrait,
21};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::DataType;
24use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
25use arrow::datatypes::FieldRef;
26use datafusion_common::cast::{
27    as_fixed_size_list_array, as_large_list_array, as_list_array,
28};
29use datafusion_common::{
30    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
31};
32use datafusion_expr::{
33    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl,
34    Signature, TypeSignature, Volatility,
35};
36use rand::rng;
37use rand::rngs::StdRng;
38use rand::{Rng, SeedableRng, seq::SliceRandom};
39use std::any::Any;
40use std::sync::Arc;
41
42#[derive(Debug, PartialEq, Eq, Hash)]
43pub struct SparkShuffle {
44    signature: Signature,
45}
46
47impl Default for SparkShuffle {
48    fn default() -> Self {
49        Self::new()
50    }
51}
52
53impl SparkShuffle {
54    pub fn new() -> Self {
55        Self {
56            signature: Signature {
57                type_signature: TypeSignature::OneOf(vec![
58                    // Only array argument
59                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
60                        arguments: vec![ArrayFunctionArgument::Array],
61                        array_coercion: None,
62                    }),
63                    // Array + Index (seed) argument
64                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
65                        arguments: vec![
66                            ArrayFunctionArgument::Array,
67                            ArrayFunctionArgument::Index,
68                        ],
69                        array_coercion: None,
70                    }),
71                ]),
72                volatility: Volatility::Volatile,
73                parameter_names: None,
74            },
75        }
76    }
77}
78
79impl ScalarUDFImpl for SparkShuffle {
80    fn as_any(&self) -> &dyn Any {
81        self
82    }
83
84    fn name(&self) -> &str {
85        "shuffle"
86    }
87
88    fn signature(&self) -> &Signature {
89        &self.signature
90    }
91
92    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
93        internal_err!("return_field_from_args should be used instead")
94    }
95
96    fn return_field_from_args(
97        &self,
98        args: datafusion_expr::ReturnFieldArgs,
99    ) -> Result<FieldRef> {
100        // Shuffle returns an array with the same type and nullability as the input
101        Ok(Arc::clone(&args.arg_fields[0]))
102    }
103
104    fn invoke_with_args(
105        &self,
106        args: datafusion_expr::ScalarFunctionArgs,
107    ) -> Result<ColumnarValue> {
108        if args.args.is_empty() {
109            return exec_err!("shuffle expects at least 1 argument");
110        }
111        if args.args.len() > 2 {
112            return exec_err!("shuffle expects at most 2 arguments");
113        }
114
115        // Extract seed from second argument if present
116        let seed = if args.args.len() == 2 {
117            extract_seed(&args.args[1])?
118        } else {
119            None
120        };
121
122        // Convert arguments to arrays
123        let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?;
124        array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array)
125    }
126}
127
128/// Extract seed value from ColumnarValue
129fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
130    match seed_arg {
131        ColumnarValue::Scalar(scalar) => {
132            let seed = match scalar {
133                ScalarValue::Int64(Some(v)) => Some(*v as u64),
134                ScalarValue::Null => None,
135                _ => {
136                    return exec_err!(
137                        "shuffle seed must be Int64 type, got '{}'",
138                        scalar.data_type()
139                    );
140                }
141            };
142            Ok(seed)
143        }
144        ColumnarValue::Array(_) => {
145            exec_err!("shuffle seed must be a scalar value, not an array")
146        }
147    }
148}
149
150/// array_shuffle SQL function with optional seed
151fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayRef> {
152    let [input_array] = take_function_args("shuffle", arg)?;
153    match &input_array.data_type() {
154        List(field) => {
155            let array = as_list_array(input_array)?;
156            general_array_shuffle::<i32>(array, field, seed)
157        }
158        LargeList(field) => {
159            let array = as_large_list_array(input_array)?;
160            general_array_shuffle::<i64>(array, field, seed)
161        }
162        FixedSizeList(field, _) => {
163            let array = as_fixed_size_list_array(input_array)?;
164            fixed_size_array_shuffle(array, field, seed)
165        }
166        Null => Ok(Arc::clone(input_array)),
167        array_type => exec_err!("shuffle does not support type '{array_type}'."),
168    }
169}
170
171fn general_array_shuffle<O: OffsetSizeTrait>(
172    array: &GenericListArray<O>,
173    field: &FieldRef,
174    seed: Option<u64>,
175) -> Result<ArrayRef> {
176    let values = array.values();
177    let original_data = values.to_data();
178    let capacity = Capacities::Array(original_data.len());
179    let mut offsets = vec![O::usize_as(0)];
180    let mut nulls = vec![];
181    let mut mutable =
182        MutableArrayData::with_capacities(vec![&original_data], false, capacity);
183    let mut rng = if let Some(s) = seed {
184        StdRng::seed_from_u64(s)
185    } else {
186        // Use a random seed from the thread-local RNG
187        let seed = rng().random::<u64>();
188        StdRng::seed_from_u64(seed)
189    };
190
191    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
192        // skip the null value
193        if array.is_null(row_index) {
194            nulls.push(false);
195            offsets.push(offsets[row_index] + O::one());
196            mutable.extend(0, 0, 1);
197            continue;
198        }
199        nulls.push(true);
200        let start = offset_window[0];
201        let end = offset_window[1];
202        let length = (end - start).to_usize().unwrap();
203
204        // Create indices and shuffle them
205        let mut indices: Vec<usize> =
206            (start.to_usize().unwrap()..end.to_usize().unwrap()).collect();
207        indices.shuffle(&mut rng);
208
209        // Add shuffled elements
210        for &index in &indices {
211            mutable.extend(0, index, index + 1);
212        }
213
214        offsets.push(offsets[row_index] + O::usize_as(length));
215    }
216
217    let data = mutable.freeze();
218    Ok(Arc::new(GenericListArray::<O>::try_new(
219        Arc::clone(field),
220        OffsetBuffer::<O>::new(offsets.into()),
221        arrow::array::make_array(data),
222        Some(nulls.into()),
223    )?))
224}
225
226fn fixed_size_array_shuffle(
227    array: &FixedSizeListArray,
228    field: &FieldRef,
229    seed: Option<u64>,
230) -> Result<ArrayRef> {
231    let values = array.values();
232    let original_data = values.to_data();
233    let capacity = Capacities::Array(original_data.len());
234    let mut nulls = vec![];
235    let mut mutable =
236        MutableArrayData::with_capacities(vec![&original_data], false, capacity);
237    let value_length = array.value_length() as usize;
238    let mut rng = if let Some(s) = seed {
239        StdRng::seed_from_u64(s)
240    } else {
241        // Use a random seed from the thread-local RNG
242        let seed = rng().random::<u64>();
243        StdRng::seed_from_u64(seed)
244    };
245
246    for row_index in 0..array.len() {
247        // skip the null value
248        if array.is_null(row_index) {
249            nulls.push(false);
250            mutable.extend(0, 0, value_length);
251            continue;
252        }
253        nulls.push(true);
254
255        let start = row_index * value_length;
256        let end = start + value_length;
257
258        // Create indices and shuffle them
259        let mut indices: Vec<usize> = (start..end).collect();
260        indices.shuffle(&mut rng);
261
262        // Add shuffled elements
263        for &index in &indices {
264            mutable.extend(0, index, index + 1);
265        }
266    }
267
268    let data = mutable.freeze();
269    Ok(Arc::new(FixedSizeListArray::try_new(
270        Arc::clone(field),
271        array.value_length(),
272        arrow::array::make_array(data),
273        Some(nulls.into()),
274    )?))
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Field;
    use datafusion_expr::ReturnFieldArgs;

    /// The returned field must have the same nullability and data type as
    /// the input field, for both non-nullable and nullable inputs.
    #[test]
    fn test_shuffle_nullability() {
        let shuffle = SparkShuffle::new();

        for nullable in [false, true] {
            let input_field = Arc::new(Field::new(
                "arr",
                List(Arc::new(Field::new("item", DataType::Int32, true))),
                nullable,
            ));

            let result = shuffle
                .return_field_from_args(ReturnFieldArgs {
                    arg_fields: &[Arc::clone(&input_field)],
                    scalar_arguments: &[None],
                })
                .unwrap();

            // The result mirrors the input's nullability and type.
            assert_eq!(result.is_nullable(), nullable);
            assert_eq!(result.data_type(), input_field.data_type());
        }
    }
}