datafusion_spark/function/array/
shuffle.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{
19    Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData,
20    OffsetSizeTrait,
21};
22use arrow::buffer::OffsetBuffer;
23use arrow::datatypes::DataType;
24use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
25use arrow::datatypes::FieldRef;
26use datafusion_common::cast::{
27    as_fixed_size_list_array, as_large_list_array, as_list_array,
28};
29use datafusion_common::{exec_err, utils::take_function_args, Result, ScalarValue};
30use datafusion_expr::{
31    ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl,
32    Signature, TypeSignature, Volatility,
33};
34use rand::rng;
35use rand::rngs::StdRng;
36use rand::{seq::SliceRandom, Rng, SeedableRng};
37use std::any::Any;
38use std::sync::Arc;
39
/// Scalar function `shuffle`: returns each input list with its elements in a
/// random order.
///
/// Accepts either a single list argument or a list plus an integer seed; see
/// [`SparkShuffle::new`] for the accepted signatures.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkShuffle {
    /// Accepted argument shapes and volatility for this function.
    signature: Signature,
}
44
45impl Default for SparkShuffle {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl SparkShuffle {
52    pub fn new() -> Self {
53        Self {
54            signature: Signature {
55                type_signature: TypeSignature::OneOf(vec![
56                    // Only array argument
57                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
58                        arguments: vec![ArrayFunctionArgument::Array],
59                        array_coercion: None,
60                    }),
61                    // Array + Index (seed) argument
62                    TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
63                        arguments: vec![
64                            ArrayFunctionArgument::Array,
65                            ArrayFunctionArgument::Index,
66                        ],
67                        array_coercion: None,
68                    }),
69                ]),
70                volatility: Volatility::Volatile,
71                parameter_names: None,
72            },
73        }
74    }
75}
76
77impl ScalarUDFImpl for SparkShuffle {
78    fn as_any(&self) -> &dyn Any {
79        self
80    }
81
82    fn name(&self) -> &str {
83        "shuffle"
84    }
85
86    fn signature(&self) -> &Signature {
87        &self.signature
88    }
89
90    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
91        Ok(arg_types[0].clone())
92    }
93
94    fn invoke_with_args(
95        &self,
96        args: datafusion_expr::ScalarFunctionArgs,
97    ) -> Result<ColumnarValue> {
98        if args.args.is_empty() {
99            return exec_err!("shuffle expects at least 1 argument");
100        }
101        if args.args.len() > 2 {
102            return exec_err!("shuffle expects at most 2 arguments");
103        }
104
105        // Extract seed from second argument if present
106        let seed = if args.args.len() == 2 {
107            extract_seed(&args.args[1])?
108        } else {
109            None
110        };
111
112        // Convert arguments to arrays
113        let arrays = ColumnarValue::values_to_arrays(&args.args[..1])?;
114        array_shuffle_with_seed(&arrays, seed).map(ColumnarValue::Array)
115    }
116}
117
118/// Extract seed value from ColumnarValue
119fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
120    match seed_arg {
121        ColumnarValue::Scalar(scalar) => {
122            let seed = match scalar {
123                ScalarValue::Int64(Some(v)) => Some(*v as u64),
124                ScalarValue::Null => None,
125                _ => {
126                    return exec_err!(
127                        "shuffle seed must be Int64 type, got '{}'",
128                        scalar.data_type()
129                    );
130                }
131            };
132            Ok(seed)
133        }
134        ColumnarValue::Array(_) => {
135            exec_err!("shuffle seed must be a scalar value, not an array")
136        }
137    }
138}
139
140/// array_shuffle SQL function with optional seed
141fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayRef> {
142    let [input_array] = take_function_args("shuffle", arg)?;
143    match &input_array.data_type() {
144        List(field) => {
145            let array = as_list_array(input_array)?;
146            general_array_shuffle::<i32>(array, field, seed)
147        }
148        LargeList(field) => {
149            let array = as_large_list_array(input_array)?;
150            general_array_shuffle::<i64>(array, field, seed)
151        }
152        FixedSizeList(field, _) => {
153            let array = as_fixed_size_list_array(input_array)?;
154            fixed_size_array_shuffle(array, field, seed)
155        }
156        Null => Ok(Arc::clone(input_array)),
157        array_type => exec_err!("shuffle does not support type '{array_type}'."),
158    }
159}
160
161fn general_array_shuffle<O: OffsetSizeTrait>(
162    array: &GenericListArray<O>,
163    field: &FieldRef,
164    seed: Option<u64>,
165) -> Result<ArrayRef> {
166    let values = array.values();
167    let original_data = values.to_data();
168    let capacity = Capacities::Array(original_data.len());
169    let mut offsets = vec![O::usize_as(0)];
170    let mut nulls = vec![];
171    let mut mutable =
172        MutableArrayData::with_capacities(vec![&original_data], false, capacity);
173    let mut rng = if let Some(s) = seed {
174        StdRng::seed_from_u64(s)
175    } else {
176        // Use a random seed from the thread-local RNG
177        let seed = rng().random::<u64>();
178        StdRng::seed_from_u64(seed)
179    };
180
181    for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
182        // skip the null value
183        if array.is_null(row_index) {
184            nulls.push(false);
185            offsets.push(offsets[row_index] + O::one());
186            mutable.extend(0, 0, 1);
187            continue;
188        }
189        nulls.push(true);
190        let start = offset_window[0];
191        let end = offset_window[1];
192        let length = (end - start).to_usize().unwrap();
193
194        // Create indices and shuffle them
195        let mut indices: Vec<usize> =
196            (start.to_usize().unwrap()..end.to_usize().unwrap()).collect();
197        indices.shuffle(&mut rng);
198
199        // Add shuffled elements
200        for &index in &indices {
201            mutable.extend(0, index, index + 1);
202        }
203
204        offsets.push(offsets[row_index] + O::usize_as(length));
205    }
206
207    let data = mutable.freeze();
208    Ok(Arc::new(GenericListArray::<O>::try_new(
209        Arc::clone(field),
210        OffsetBuffer::<O>::new(offsets.into()),
211        arrow::array::make_array(data),
212        Some(nulls.into()),
213    )?))
214}
215
216fn fixed_size_array_shuffle(
217    array: &FixedSizeListArray,
218    field: &FieldRef,
219    seed: Option<u64>,
220) -> Result<ArrayRef> {
221    let values = array.values();
222    let original_data = values.to_data();
223    let capacity = Capacities::Array(original_data.len());
224    let mut nulls = vec![];
225    let mut mutable =
226        MutableArrayData::with_capacities(vec![&original_data], false, capacity);
227    let value_length = array.value_length() as usize;
228    let mut rng = if let Some(s) = seed {
229        StdRng::seed_from_u64(s)
230    } else {
231        // Use a random seed from the thread-local RNG
232        let seed = rng().random::<u64>();
233        StdRng::seed_from_u64(seed)
234    };
235
236    for row_index in 0..array.len() {
237        // skip the null value
238        if array.is_null(row_index) {
239            nulls.push(false);
240            mutable.extend(0, 0, value_length);
241            continue;
242        }
243        nulls.push(true);
244
245        let start = row_index * value_length;
246        let end = start + value_length;
247
248        // Create indices and shuffle them
249        let mut indices: Vec<usize> = (start..end).collect();
250        indices.shuffle(&mut rng);
251
252        // Add shuffled elements
253        for &index in &indices {
254            mutable.extend(0, index, index + 1);
255        }
256    }
257
258    let data = mutable.freeze();
259    Ok(Arc::new(FixedSizeListArray::try_new(
260        Arc::clone(field),
261        array.value_length(),
262        arrow::array::make_array(data),
263        Some(nulls.into()),
264    )?))
265}