use std::{any::Any, sync::Arc};
use arrow::array::{Array, ArrayRef, new_null_array};
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::utils::SingleRowListArrayBuilder;
use datafusion_common::{Result, internal_err};
use datafusion_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
use datafusion_functions_nested::make_array::{array_array, coerce_types_inner};
use crate::function::functions_nested_utils::make_scalar_function;
const ARRAY_FIELD_DEFAULT_NAME: &str = "element";
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkArray {
signature: Signature,
}
impl Default for SparkArray {
fn default() -> Self {
Self::new()
}
}
impl SparkArray {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
}
}
}
impl ScalarUDFImpl for SparkArray {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"array"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("return_field_from_args should be used instead")
}
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
let data_types = args
.arg_fields
.iter()
.map(|f| f.data_type())
.cloned()
.collect::<Vec<_>>();
let mut expr_type = DataType::Null;
for arg_type in &data_types {
if !arg_type.equals_datatype(&DataType::Null) {
expr_type = arg_type.clone();
break;
}
}
let return_type = DataType::List(Arc::new(Field::new(
ARRAY_FIELD_DEFAULT_NAME,
expr_type,
true,
)));
Ok(Arc::new(Field::new(
"this_field_name_is_irrelevant",
return_type,
false,
)))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs { args, .. } = args;
make_scalar_function(make_array_inner)(args.as_slice())
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.is_empty() {
Ok(vec![])
} else {
coerce_types_inner(arg_types, self.name())
}
}
}
pub fn make_array_inner(arrays: &[ArrayRef]) -> Result<ArrayRef> {
let mut data_type = DataType::Null;
for arg in arrays {
let arg_data_type = arg.data_type();
if !arg_data_type.equals_datatype(&DataType::Null) {
data_type = arg_data_type.clone();
break;
}
}
match data_type {
DataType::Null => {
let length = arrays.iter().map(|a| a.len()).sum();
let array = new_null_array(&DataType::Null, length);
Ok(Arc::new(
SingleRowListArrayBuilder::new(array)
.with_nullable(true)
.with_field_name(Some(ARRAY_FIELD_DEFAULT_NAME.to_string()))
.build_list_array(),
))
}
_ => array_array::<i32>(arrays, data_type, ARRAY_FIELD_DEFAULT_NAME),
}
}