Skip to main content

datafusion_spark/function/array/
shuffle.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData,
20    OffsetSizeTrait,
21};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::DataType;
24use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
25use arrow::datatypes::FieldRef;
26use datafusion_common::cast::{
27    as_fixed_size_list_array, as_large_list_array, as_list_array,
28};
29use datafusion_common::{
30    Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
31};
32use datafusion_expr::{
33    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarFunctionArgs,
34    ScalarUDFImpl, Signature, TypeSignature, Volatility,
35};
36use rand::rng;
37use rand::rngs::StdRng;
38use rand::{Rng, SeedableRng, seq::SliceRandom};
39use std::sync::Arc;
40
41#[derive(Debug, PartialEq, Eq, Hash)]
42pub struct SparkShuffle {
43    signature: Signature,
44}
45
46impl Default for SparkShuffle {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl SparkShuffle {
53    pub fn new() -> Self {
54        Self {
55            signature: Signature {
56                type_signature: TypeSignature::OneOf(vec![
57                    // Only array argument
58                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
59                        arguments: vec![ArrayFunctionArgument::Array],
60                        array_coercion: None,
61                    }),
62                    // Array + Index (seed) argument
63                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
64                        arguments: vec![
65                            ArrayFunctionArgument::Array,
66                            ArrayFunctionArgument::Index,
67                        ],
68                        array_coercion: None,
69                    }),
70                ]),
71                volatility: Volatility::Volatile,
72                parameter_names: None,
73            },
74        }
75    }
76}
77
78impl ScalarUDFImpl for SparkShuffle {
79    fn name(&self) -> &str {
80        "shuffle"
81    }
82
83    fn signature(&self) -> &Signature {
84        &self.signature
85    }
86
87    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
88        internal_err!("return_field_from_args should be used instead")
89    }
90
91    fn return_field_from_args(
92        &self,
93        args: datafusion_expr::ReturnFieldArgs,
94    ) -> Result<FieldRef> {
95        // Shuffle returns an array with the same type and nullability as the input
96        Ok(Arc::clone(&args.arg_fields[0]))
97    }
98
99    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
100        if args.args.is_empty() || args.args.len() > 2 {
101            return exec_err!("shuffle expects 1 or 2 argument(s)");
102        }
103
104        // Extract seed from second argument if present
105        let seed = if args.args.len() == 2 {
106            extract_seed(&args.args[1])?
107        } else {
108            None
109        };
110
111        // Convert arguments to arrays
112        let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?;
113        array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array)
114    }
115}
116
117/// Extract seed value from ColumnarValue
118fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
119    match seed_arg {
120        ColumnarValue::Scalar(scalar) => {
121            let seed = match scalar {
122                ScalarValue::Int64(Some(v)) => Some(*v as u64),
123                ScalarValue::Null | ScalarValue::Int64(None) => None,
124                _ => {
125                    return exec_err!(
126                        "shuffle seed must be Int64 type but got '{}'",
127                        scalar.data_type()
128                    );
129                }
130            };
131            Ok(seed)
132        }
133        ColumnarValue::Array(_) => {
134            exec_err!("shuffle seed must be a scalar value, not an array")
135        }
136    }
137}
138
139/// array_shuffle SQL function with optional seed
140fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayRef> {
141    let [input_array] = take_function_args("shuffle", arg)?;
142    match &input_array.data_type() {
143        List(field) => {
144            let array = as_list_array(input_array)?;
145            general_array_shuffle::<i32>(array, field, seed)
146        }
147        LargeList(field) => {
148            let array = as_large_list_array(input_array)?;
149            general_array_shuffle::<i64>(array, field, seed)
150        }
151        FixedSizeList(field, _) => {
152            let array = as_fixed_size_list_array(input_array)?;
153            fixed_size_array_shuffle(array, field, seed)
154        }
155        Null => Ok(Arc::clone(input_array)),
156        array_type => exec_err!(
157            "shuffle does not support type '{array_type}'; \
158        expected types: List, LargeList, FixedSizeList or Null."
159        ),
160    }
161}
162
163fn general_array_shuffle<O: OffsetSizeTrait>(
164    array: &GenericListArray<O>,
165    field: &FieldRef,
166    seed: Option<u64>,
167) -> Result<ArrayRef> {
168    let values = array.values();
169    let original_data = values.to_data();
170    let capacity = Capacities::Array(original_data.len());
171    let mut offsets = vec![O::usize_as(0)];
172    let mut nulls = vec![];
173    let mut mutable =
174        MutableArrayData::with_capacities(vec![&original_data], false, capacity);
175    let mut rng = if let Some(s) = seed {
176        StdRng::seed_from_u64(s)
177    } else {
178        // Use a random seed from the thread-local RNG
179        let seed = rng().random::<u64>();
180        StdRng::seed_from_u64(seed)
181    };
182
183    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
184        // skip the null value
185        if array.is_null(row_index) {
186            nulls.push(false);
187            offsets.push(offsets[row_index] + O::one());
188            mutable.extend(0, 0, 1);
189            continue;
190        }
191        nulls.push(true);
192        let start = offset_window[0];
193        let end = offset_window[1];
194        let length = (end - start).to_usize().unwrap();
195
196        // Create indices and shuffle them
197        let mut indices: Vec<usize> =
198            (start.to_usize().unwrap()..end.to_usize().unwrap()).collect();
199        indices.shuffle(&mut rng);
200
201        // Add shuffled elements
202        for &index in &indices {
203            mutable.extend(0, index, index + 1);
204        }
205
206        offsets.push(offsets[row_index] + O::usize_as(length));
207    }
208
209    let data = mutable.freeze();
210    Ok(Arc::new(GenericListArray::<O>::try_new(
211        Arc::clone(field),
212        OffsetBuffer::<O>::new(offsets.into()),
213        arrow::array::make_array(data),
214        Some(nulls.into()),
215    )?))
216}
217
218fn fixed_size_array_shuffle(
219    array: &FixedSizeListArray,
220    field: &FieldRef,
221    seed: Option<u64>,
222) -> Result<ArrayRef> {
223    let values = array.values();
224    let original_data = values.to_data();
225    let capacity = Capacities::Array(original_data.len());
226    let mut nulls = vec![];
227    let mut mutable =
228        MutableArrayData::with_capacities(vec![&original_data], false, capacity);
229    let value_length = array.value_length() as usize;
230    let mut rng = if let Some(s) = seed {
231        StdRng::seed_from_u64(s)
232    } else {
233        // Use a random seed from the thread-local RNG
234        let seed = rng().random::<u64>();
235        StdRng::seed_from_u64(seed)
236    };
237
238    for row_index in 0..array.len() {
239        // skip the null value
240        if array.is_null(row_index) {
241            nulls.push(false);
242            mutable.extend(0, 0, value_length);
243            continue;
244        }
245        nulls.push(true);
246
247        let start = row_index * value_length;
248        let end = start + value_length;
249
250        // Create indices and shuffle them
251        let mut indices: Vec<usize> = (start..end).collect();
252        indices.shuffle(&mut rng);
253
254        // Add shuffled elements
255        for &index in &indices {
256            mutable.extend(0, index, index + 1);
257        }
258    }
259
260    let data = mutable.freeze();
261    Ok(Arc::new(FixedSizeListArray::try_new(
262        Arc::clone(field),
263        array.value_length(),
264        arrow::array::make_array(data),
265        Some(nulls.into()),
266    )?))
267}
268
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Field;
    use datafusion_expr::ReturnFieldArgs;

    /// The returned field must mirror the input field's data type and
    /// nullability, for both non-nullable and nullable inputs.
    #[test]
    fn test_shuffle_nullability() {
        let shuffle = SparkShuffle::new();

        for nullable in [false, true] {
            let item_field = Arc::new(Field::new("item", DataType::Int32, true));
            let input_field = Arc::new(Field::new("arr", List(item_field), nullable));

            let result = shuffle
                .return_field_from_args(ReturnFieldArgs {
                    arg_fields: &[Arc::clone(&input_field)],
                    scalar_arguments: &[None],
                })
                .unwrap();

            // Same nullability and type as the input
            assert_eq!(result.is_nullable(), nullable);
            assert_eq!(result.data_type(), input_field.data_type());
        }
    }
}