use arrow::array::{
Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData,
OffsetSizeTrait,
};
use arrow::buffer::OffsetBuffer;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null};
use arrow::datatypes::FieldRef;
use datafusion_common::cast::{
as_fixed_size_list_array, as_large_list_array, as_list_array,
};
use datafusion_common::{
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ScalarUDFImpl,
Signature, TypeSignature, Volatility,
};
use rand::rng;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng, seq::SliceRandom};
use std::any::Any;
use std::sync::Arc;
/// Scalar UDF `shuffle`: returns each input array row with its elements
/// randomly permuted.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct SparkShuffle {
    /// Accepted argument shapes and volatility of this UDF.
    signature: Signature,
}
impl Default for SparkShuffle {
fn default() -> Self {
Self::new()
}
}
impl SparkShuffle {
    /// Creates the `shuffle` UDF.
    ///
    /// Two call shapes are accepted: `shuffle(array)` and
    /// `shuffle(array, seed)`, where `seed` is declared as an
    /// `ArrayFunctionArgument::Index`. The function is volatile because an
    /// unseeded call draws a fresh random seed on every invocation.
    pub fn new() -> Self {
        let array_only = ArrayFunctionSignature::Array {
            arguments: vec![ArrayFunctionArgument::Array],
            array_coercion: None,
        };
        let array_with_seed = ArrayFunctionSignature::Array {
            arguments: vec![ArrayFunctionArgument::Array, ArrayFunctionArgument::Index],
            array_coercion: None,
        };
        let type_signature = TypeSignature::OneOf(
            [array_only, array_with_seed]
                .into_iter()
                .map(TypeSignature::ArraySignature)
                .collect(),
        );
        Self {
            signature: Signature {
                type_signature,
                volatility: Volatility::Volatile,
                parameter_names: None,
            },
        }
    }
}
impl ScalarUDFImpl for SparkShuffle {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "shuffle"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be used instead")
    }

    /// The result field is the input array's field unchanged (same name,
    /// type and nullability): shuffling never changes the element type and
    /// null rows stay null.
    fn return_field_from_args(
        &self,
        args: datafusion_expr::ReturnFieldArgs,
    ) -> Result<FieldRef> {
        Ok(Arc::clone(&args.arg_fields[0]))
    }

    /// Shuffles every row of the first argument, optionally seeded by the
    /// second (scalar Int64) argument.
    ///
    /// A scalar input array is expanded to `number_rows` rows *before*
    /// shuffling. Otherwise a single permutation would be computed once and
    /// broadcast to the whole batch, which is wrong for a volatile function
    /// where each row should receive its own random permutation.
    fn invoke_with_args(
        &self,
        args: datafusion_expr::ScalarFunctionArgs,
    ) -> Result<ColumnarValue> {
        let (input, seed) = match args.args.as_slice() {
            [input] => (input, None),
            [input, seed_arg] => (input, extract_seed(seed_arg)?),
            _ => return exec_err!("shuffle expects 1 or 2 argument(s)"),
        };
        let input_array = input.to_array(args.number_rows)?;
        array_shuffle_with_seed(&[input_array], seed).map(ColumnarValue::Array)
    }
}
/// Reads the optional shuffle seed from the second UDF argument.
///
/// Only a scalar is accepted. A non-null `Int64` yields `Some(seed)`, and
/// negative values are reinterpreted bit-for-bit as `u64` by the `as` cast.
/// A typed or untyped SQL NULL yields `None`, meaning "use a random seed".
///
/// # Errors
/// Fails if the argument is an array or a scalar of any type other than
/// `Int64`/`Null`.
fn extract_seed(seed_arg: &ColumnarValue) -> Result<Option<u64>> {
    let ColumnarValue::Scalar(scalar) = seed_arg else {
        return exec_err!("shuffle seed must be a scalar value, not an array");
    };
    match scalar {
        ScalarValue::Int64(value) => Ok(value.map(|v| v as u64)),
        ScalarValue::Null => Ok(None),
        other => exec_err!(
            "shuffle seed must be Int64 type but got '{}'",
            other.data_type()
        ),
    }
}
/// Dispatches the shuffle on the type of the single input array.
///
/// `List`, `LargeList` and `FixedSizeList` inputs have each row's elements
/// permuted. A `Null`-typed input is returned as-is. Any other type is an
/// execution error.
fn array_shuffle_with_seed(arg: &[ArrayRef], seed: Option<u64>) -> Result<ArrayRef> {
    let [input_array] = take_function_args("shuffle", arg)?;
    let data_type = input_array.data_type();
    if matches!(data_type, Null) {
        // Nothing to permute: every value of a Null-typed array is null.
        return Ok(Arc::clone(input_array));
    }
    match data_type {
        List(field) => {
            general_array_shuffle::<i32>(as_list_array(input_array)?, field, seed)
        }
        LargeList(field) => {
            general_array_shuffle::<i64>(as_large_list_array(input_array)?, field, seed)
        }
        FixedSizeList(field, _) => {
            fixed_size_array_shuffle(as_fixed_size_list_array(input_array)?, field, seed)
        }
        array_type => exec_err!(
            "shuffle does not support type '{array_type}'; \
            expected types: List, LargeList, FixedSizeList or Null."
        ),
    }
}
/// Shuffles the elements of every non-null row of a `List` / `LargeList`
/// array.
///
/// All rows draw from one `StdRng`. It is seeded with `seed` when given, so
/// the same input and seed always give the same output; otherwise it is
/// seeded from `rand::rng()`. Null rows stay null and are emitted as empty
/// (zero-length) slots. The input's validity bitmap is reused unchanged.
///
/// # Errors
/// Returns an error if the rebuilt `GenericListArray` fails validation.
fn general_array_shuffle<O: OffsetSizeTrait>(
    array: &GenericListArray<O>,
    field: &FieldRef,
    seed: Option<u64>,
) -> Result<ArrayRef> {
    let original_data = array.values().to_data();
    let capacity = Capacities::Array(original_data.len());
    let mut mutable =
        MutableArrayData::with_capacities(vec![&original_data], false, capacity);
    let mut shuffle_rng =
        StdRng::seed_from_u64(seed.unwrap_or_else(|| rng().random::<u64>()));

    let mut offsets: Vec<O> = Vec::with_capacity(array.len() + 1);
    offsets.push(O::usize_as(0));
    let mut next_offset = 0usize;
    // Reused across rows to avoid one allocation per row.
    let mut indices: Vec<usize> = Vec::new();

    for (row_index, window) in array.offsets().windows(2).enumerate() {
        // Null rows contribute no child values. Copying a placeholder child
        // value for them would panic when the child array is empty (e.g. when
        // every row is null).
        if array.is_valid(row_index) {
            // Offsets index directly into `array.values()`.
            let start = window[0].as_usize();
            let end = window[1].as_usize();
            indices.clear();
            indices.extend(start..end);
            indices.shuffle(&mut shuffle_rng);
            for &index in &indices {
                mutable.extend(0, index, index + 1);
            }
            next_offset += end - start;
        }
        offsets.push(O::usize_as(next_offset));
    }

    let data = mutable.freeze();
    Ok(Arc::new(GenericListArray::<O>::try_new(
        Arc::clone(field),
        OffsetBuffer::<O>::new(offsets.into()),
        arrow::array::make_array(data),
        array.nulls().cloned(),
    )?))
}
/// Shuffles the elements of every non-null row of a `FixedSizeList` array.
///
/// Each row's `value_length` elements are permuted independently. All rows
/// draw from one `StdRng`, seeded with `seed` when given and otherwise
/// seeded from `rand::rng()`. Null rows stay null and keep their own child
/// values untouched. The input's validity bitmap is reused unchanged.
///
/// # Errors
/// Returns an error if the rebuilt `FixedSizeListArray` fails validation.
fn fixed_size_array_shuffle(
    array: &FixedSizeListArray,
    field: &FieldRef,
    seed: Option<u64>,
) -> Result<ArrayRef> {
    let original_data = array.values().to_data();
    let capacity = Capacities::Array(original_data.len());
    let mut mutable =
        MutableArrayData::with_capacities(vec![&original_data], false, capacity);
    let value_length = array.value_length() as usize;
    let mut shuffle_rng =
        StdRng::seed_from_u64(seed.unwrap_or_else(|| rng().random::<u64>()));
    // Reused across rows to avoid one allocation per row.
    let mut indices: Vec<usize> = Vec::with_capacity(value_length);

    for row_index in 0..array.len() {
        // Row `i` occupies child slots [i * value_length, (i + 1) * value_length).
        let start = row_index * value_length;
        let end = start + value_length;
        if array.is_null(row_index) {
            // Every slot must still hold exactly `value_length` child values;
            // copy the null row's own (masked) values rather than row 0's.
            mutable.extend(0, start, end);
            continue;
        }
        indices.clear();
        indices.extend(start..end);
        indices.shuffle(&mut shuffle_rng);
        for &index in &indices {
            mutable.extend(0, index, index + 1);
        }
    }

    let data = mutable.freeze();
    Ok(Arc::new(FixedSizeListArray::try_new(
        Arc::clone(field),
        array.value_length(),
        arrow::array::make_array(data),
        array.nulls().cloned(),
    )?))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::datatypes::Field;
    use datafusion_expr::ReturnFieldArgs;

    /// The return field must mirror the input field exactly, for both a
    /// non-nullable and a nullable list argument.
    #[test]
    fn test_shuffle_nullability() {
        let shuffle = SparkShuffle::new();
        for nullable in [false, true] {
            let input_field = Arc::new(Field::new(
                "arr",
                List(Arc::new(Field::new("item", DataType::Int32, true))),
                nullable,
            ));
            let result = shuffle
                .return_field_from_args(ReturnFieldArgs {
                    arg_fields: &[Arc::clone(&input_field)],
                    scalar_arguments: &[None],
                })
                .unwrap();
            assert_eq!(result.is_nullable(), nullable);
            assert_eq!(result.data_type(), input_field.data_type());
        }
    }
}