use arrow::datatypes::{DataType, Field};
use datafusion_common::utils::take_function_args;
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::repeat::ArrayRepeat;
use std::any::Any;
use std::sync::Arc;
use crate::function::null_utils::{
NullMaskResolution, apply_null_mask, compute_null_mask,
};
/// Scalar UDF registered under the name `array_repeat` in this Spark
/// function module.
///
/// Evaluation is delegated to DataFusion's `ArrayRepeat`, wrapped with this
/// crate's null-mask helpers.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkArrayRepeat {
// Signature handed back by `ScalarUDFImpl::signature`.
signature: Signature,
}
impl Default for SparkArrayRepeat {
fn default() -> Self {
Self::new()
}
}
impl SparkArrayRepeat {
    /// Creates the function with an immutable, user-defined signature.
    ///
    /// A user-defined signature means argument coercion is handled by this
    /// type's `coerce_types` implementation.
    pub fn new() -> Self {
        let signature = Signature::user_defined(Volatility::Immutable);
        Self { signature }
    }
}
/// [`ScalarUDFImpl`] wiring for the `array_repeat` function.
impl ScalarUDFImpl for SparkArrayRepeat {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is exposed.
    fn name(&self) -> &str {
        "array_repeat"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Returns `List<T>` with a nullable item field, where `T` is the type of
    /// the first argument (the element being repeated).
    ///
    /// # Errors
    ///
    /// Returns an execution error when `arg_types` is empty, since there is
    /// then no element type to build the list from.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        // Checked access: an empty slice becomes an error rather than an
        // index-out-of-bounds panic.
        let Some(element_type) = arg_types.first() else {
            return exec_err!(
                "{} requires an element argument to determine its return type",
                self.name()
            );
        };
        Ok(DataType::List(Arc::new(Field::new_list_field(
            element_type.clone(),
            true,
        ))))
    }

    /// Evaluates the function; see `spark_array_repeat`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        spark_array_repeat(args)
    }

    /// Validates the arity (exactly two arguments) and widens the count type.
    ///
    /// The element type passes through unchanged. The count is widened to
    /// `Int64` for any signed integer type and to `UInt64` for any unsigned
    /// integer type.
    ///
    /// # Errors
    ///
    /// Returns an error when the number of arguments is not two, or when the
    /// count argument is not an integer type.
    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
        let [first_type, second_type] = take_function_args(self.name(), arg_types)?;
        let second = match second_type {
            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
                DataType::Int64
            }
            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
                DataType::UInt64
            }
            _ => return exec_err!("count must be an integer type"),
        };
        Ok(vec![first_type.clone(), second])
    }
}
/// Evaluates `array_repeat` by delegating to DataFusion's [`ArrayRepeat`],
/// bracketed by this crate's null-mask helpers.
///
/// When `compute_null_mask` resolves to [`NullMaskResolution::ReturnNull`],
/// a null scalar of the return type is produced and `ArrayRepeat` is never
/// invoked. Otherwise the arguments are forwarded unchanged and the
/// delegate's result is post-processed by `apply_null_mask`.
///
/// # Errors
///
/// Propagates any error from the null-mask helpers, from building the null
/// scalar, or from the delegated `ArrayRepeat` invocation.
fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
    // Owned copy of the output type: `args` is moved into the delegate below,
    // but the type is still needed afterwards.
    let return_type = args.return_field.data_type().clone();

    let null_mask = compute_null_mask(&args.args, args.number_rows)?;
    if matches!(null_mask, NullMaskResolution::ReturnNull) {
        let null_scalar = ScalarValue::try_from(return_type)?;
        return Ok(ColumnarValue::Scalar(null_scalar));
    }

    let repeated = ArrayRepeat::new().invoke_with_args(args)?;
    apply_null_mask(repeated, null_mask, &return_type)
}