use arrow::array::{Array, ArrayRef, Int64Builder};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::cast::{as_int64_array, as_list_array};
use datafusion_common::utils::ListCoercion;
use datafusion_common::{
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, ReturnFieldArgs,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
};
use datafusion_functions_nested::extract::array_slice_udf;
use std::sync::Arc;
/// Scalar UDF exposed under the name `slice`.
///
/// The only state is the function [`Signature`], which is built once in
/// [`SparkSlice::new`].
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct SparkSlice {
    /// Accepted argument shapes and volatility of the function.
    signature: Signature,
}
impl Default for SparkSlice {
fn default() -> Self {
Self::new()
}
}
impl SparkSlice {
    /// Builds the UDF with its fixed signature.
    ///
    /// The signature takes one array argument followed by two index
    /// arguments, requests fixed-size-list to list coercion for the array,
    /// and is declared immutable. No parameter names are registered.
    pub fn new() -> Self {
        let arguments = vec![
            ArrayFunctionArgument::Array,
            ArrayFunctionArgument::Index,
            ArrayFunctionArgument::Index,
        ];
        let type_signature =
            TypeSignature::ArraySignature(ArrayFunctionSignature::Array {
                arguments,
                array_coercion: Some(ListCoercion::FixedSizedListToList),
            });
        let signature = Signature {
            type_signature,
            volatility: Volatility::Immutable,
            parameter_names: None,
        };
        Self { signature }
    }
}
impl ScalarUDFImpl for SparkSlice {
    /// Name under which the function is exposed.
    fn name(&self) -> &str {
        "slice"
    }

    /// Signature built in the constructor.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Always fails: the return type is derived in `return_field_from_args`.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be used instead")
    }

    /// Derives the output field from the argument fields.
    ///
    /// The output is nullable if any argument is nullable. Its data type is
    /// the type of the first argument, except that a `Null` first argument
    /// is reported as a list of nullable `Null` items.
    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        let nullable = args.arg_fields.iter().any(|field| field.is_nullable());
        let first_type = args.arg_fields[0].data_type();
        let data_type = if first_type == &DataType::Null {
            DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
        } else {
            first_type.clone()
        };
        Ok(Arc::new(Field::new("slice", data_type, nullable)))
    }

    /// Evaluates the function by translating `(start, length)` into
    /// `(start, end)` bounds and delegating to `array_slice_udf`.
    ///
    /// A first argument of type `Null` short-circuits to a scalar null list.
    fn invoke_with_args(
        &self,
        mut func_args: ScalarFunctionArgs,
    ) -> Result<ColumnarValue> {
        if func_args.args[0].data_type() == DataType::Null {
            let null_list = ScalarValue::new_null_list(DataType::Null, true, 1);
            return Ok(ColumnarValue::Scalar(null_list));
        }

        // Row count: the length of the first array argument, or the batch
        // row count when every argument is a scalar.
        let mut row_count = func_args.number_rows;
        for arg in &func_args.args {
            if let ColumnarValue::Array(array) = arg {
                row_count = array.len();
                break;
            }
        }

        // Expand every argument to an array so bounds are computed per row.
        let mut materialized = Vec::with_capacity(func_args.args.len());
        for arg in &func_args.args {
            let array = match arg {
                ColumnarValue::Array(array) => Arc::clone(array),
                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(row_count)?,
            };
            materialized.push(array);
        }
        let (start, end) = calculate_start_end(&materialized)?;

        // The original (possibly scalar) list argument is forwarded as-is;
        // only the two index arguments are replaced.
        let list_arg = func_args.args.swap_remove(0);
        let delegated = ScalarFunctionArgs {
            args: vec![
                list_arg,
                ColumnarValue::Array(start),
                ColumnarValue::Array(end),
            ],
            arg_fields: func_args.arg_fields,
            number_rows: func_args.number_rows,
            return_field: func_args.return_field,
            config_options: func_args.config_options,
        };
        array_slice_udf().invoke_with_args(delegated)
    }
}
/// Converts `slice(values, start, length)` arguments into the 1-based,
/// inclusive `(start, end)` index arrays handed to `array_slice`.
///
/// `args` must hold exactly three arrays: a list array, an `Int64` start
/// array and an `Int64` length array. For every row:
///
/// * if any of the three inputs is null, both outputs are null;
/// * a negative `start` counts from the end of the row's list
///   (`-1` is the last element);
/// * if the resolved start falls before the first element, the bounds
///   `(1, 0)` are emitted;
/// * otherwise `end = start + length - 1`, capped at the row's list length.
///
/// # Errors
///
/// Returns an execution error if a non-null row has `start == 0` or a
/// negative `length`, and propagates errors from argument-count or
/// array-type checks.
fn calculate_start_end(args: &[ArrayRef]) -> Result<(ArrayRef, ArrayRef)> {
    let [values, start, length] = take_function_args("slice", args)?;
    let values_len = values.len();
    let start = as_int64_array(&start)?;
    let length = as_int64_array(&length)?;
    let values = as_list_array(values)?;

    let mut adjusted_start = Int64Builder::with_capacity(values_len);
    let mut end = Int64Builder::with_capacity(values_len);

    for row in 0..values_len {
        if values.is_null(row) || start.is_null(row) || length.is_null(row) {
            adjusted_start.append_null();
            end.append_null();
            continue;
        }

        let start = start.value(row);
        let length = length.value(row);
        // Read the length straight from the offsets instead of building a
        // child array for the row.
        let value_length = i64::from(values.value_length(row));

        if start == 0 {
            return exec_err!("Start index must not be zero");
        }
        if length < 0 {
            return exec_err!("Length must be non-negative, but got {}", length);
        }

        // `value_length` fits in an i32, so this addition cannot overflow.
        let adjusted_start_value = if start < 0 {
            start + value_length + 1
        } else {
            start
        };

        if adjusted_start_value < 1 {
            adjusted_start.append_value(1);
            end.append_value(0);
            continue;
        }

        // `start + length - 1` can exceed i64::MAX for large inputs; a plain
        // `+` would panic in debug builds and wrap to a negative end index
        // in release builds. Saturate instead, then cap at the list length:
        // no element exists past it, so the selected range is unchanged.
        // NOTE(review): assumes `array_slice` treats an end past the list as
        // the list end — confirm against its implementation.
        let end_value = adjusted_start_value
            .saturating_add(length - 1)
            .min(value_length);

        adjusted_start.append_value(adjusted_start_value);
        end.append_value(end_value);
    }

    Ok((Arc::new(adjusted_start.finish()), Arc::new(end.finish())))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::NullArray;
    use arrow::datatypes::Field;
    use datafusion_common::ScalarValue;
    use datafusion_common::cast::as_list_array;
    use datafusion_expr::ReturnFieldArgs;

    /// `List<Null>` with a nullable item field, the type expected for a
    /// `Null` first argument.
    fn list_of_null() -> DataType {
        DataType::List(Arc::new(Field::new_list_field(DataType::Null, true)))
    }

    /// A `Null`-typed first argument must be reported as `List<Null>`.
    #[test]
    fn test_spark_slice_function_when_input_is_null() {
        let arg_fields: Vec<Arc<Field>> = [
            ("a", DataType::Null),
            ("s", DataType::Int64),
            ("l", DataType::Int64),
        ]
        .into_iter()
        .map(|(name, data_type)| Arc::new(Field::new(name, data_type, true)))
        .collect();

        let udf = SparkSlice::new();
        let return_field_args = ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[],
        };
        let out = udf.return_field_from_args(return_field_args).unwrap();

        assert_eq!(out.data_type(), &list_of_null());
    }

    /// Invoking with a `NullArray` as the list argument yields a null list.
    #[test]
    fn test_spark_slice_function_when_input_array_is_null() {
        let return_field = Arc::new(Field::new("slice", list_of_null(), true));
        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(NullArray::new(1))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
                ColumnarValue::Scalar(ScalarValue::Int64(Some(3))),
            ],
            arg_fields: vec![Arc::new(Field::new("item", DataType::Null, true))],
            number_rows: 1,
            return_field,
            config_options: Arc::new(Default::default()),
        };

        let udf = SparkSlice::new();
        let result = udf.invoke_with_args(args).unwrap();
        let arr = result.to_array(1).unwrap();
        let list = as_list_array(&arr).unwrap();

        assert_eq!(arr.data_type(), &list_of_null());
        assert!(list.is_null(0));
    }
}