use arrow::datatypes::{DataType, Field};
use datafusion_common::arrow::datatypes::FieldRef;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::ReturnFieldArgs;
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions::string::concat::ConcatFunc;
use std::any::Any;
use std::sync::Arc;
use crate::function::null_utils::{
NullMaskResolution, apply_null_mask, compute_null_mask,
};
/// Spark-compatible `concat` scalar UDF.
///
/// Evaluation is performed by `spark_concat`, which delegates the actual
/// string concatenation to DataFusion's `ConcatFunc`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkConcat {
    /// User-defined signature: argument types are handled by this UDF's own
    /// `coerce_types` rather than by a fixed list of accepted types.
    signature: Signature,
}

impl SparkConcat {
    /// Creates the UDF with an immutable, user-defined signature.
    pub fn new() -> Self {
        Self::default()
    }
}

impl Default for SparkConcat {
    fn default() -> Self {
        Self {
            signature: Signature::user_defined(Volatility::Immutable),
        }
    }
}
impl ScalarUDFImpl for SparkConcat {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"concat"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
spark_concat(args)
}
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
Ok(arg_types.to_vec())
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
datafusion_common::internal_err!(
"return_type should not be called for Spark concat"
)
}
fn return_field_from_args(&self, args: ReturnFieldArgs<'_>) -> Result<FieldRef> {
use DataType::*;
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
let mut dt = &Utf8;
for field in args.arg_fields {
let data_type = field.data_type();
if data_type == &Utf8View || (data_type == &LargeUtf8 && dt != &Utf8View) {
dt = data_type;
}
}
Ok(Arc::new(Field::new("concat", dt.clone(), nullable)))
}
}
fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs {
args: arg_values,
arg_fields,
number_rows,
return_field,
config_options,
} = args;
if arg_values.is_empty() {
let return_type = return_field.data_type();
return match return_type {
DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
String::new(),
)))),
DataType::LargeUtf8 => Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(
Some(String::new()),
))),
_ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(
Some(String::new()),
))),
};
}
let null_mask = compute_null_mask(&arg_values, number_rows)?;
if matches!(null_mask, NullMaskResolution::ReturnNull) {
let return_type = return_field.data_type();
return match return_type {
DataType::Utf8View => Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(None))),
DataType::LargeUtf8 => {
Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)))
}
_ => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))),
};
}
let concat_func = ConcatFunc::new();
let return_type = return_field.data_type().clone();
let func_args = ScalarFunctionArgs {
args: arg_values,
arg_fields,
number_rows,
return_field,
config_options,
};
let result = concat_func.invoke_with_args(func_args)?;
apply_null_mask(result, null_mask, &return_type)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::function::utils::test::test_scalar_function;
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::Result;
    use datafusion_expr::ReturnFieldArgs;
    use std::sync::Arc;

    /// Runs `SparkConcat::return_field_from_args` over synthetic argument
    /// fields described as `(data_type, nullable)` pairs.
    fn infer_return_field(specs: &[(DataType, bool)]) -> Result<Arc<Field>> {
        let fields: Vec<_> = specs
            .iter()
            .enumerate()
            .map(|(i, (data_type, nullable))| {
                Arc::new(Field::new(format!("arg{i}"), data_type.clone(), *nullable))
            })
            .collect();
        SparkConcat::new().return_field_from_args(ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[],
        })
    }

    #[test]
    fn test_concat_basic() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
            ],
            Ok(Some("SparkSQL")),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    #[test]
    fn test_concat_with_null() -> Result<()> {
        test_scalar_function!(
            SparkConcat::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("Spark".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("SQL".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            DataType::Utf8,
            StringArray
        );
        Ok(())
    }

    #[test]
    fn test_spark_concat_return_field_non_nullable() -> Result<()> {
        let func = SparkConcat::new();
        let fields = vec![
            Arc::new(Field::new("a", DataType::Utf8, false)),
            Arc::new(Field::new("b", DataType::Utf8, false)),
        ];
        let args = ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[],
        };
        let field = func.return_field_from_args(args)?;
        assert!(
            !field.is_nullable(),
            "Expected concat result to be non-nullable when all inputs are non-nullable"
        );
        Ok(())
    }

    #[test]
    fn test_spark_concat_return_field_nullable() -> Result<()> {
        let func = SparkConcat::new();
        let fields = vec![
            Arc::new(Field::new("a", DataType::Utf8, false)),
            Arc::new(Field::new("b", DataType::Utf8, true)),
        ];
        let args = ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[],
        };
        let field = func.return_field_from_args(args)?;
        assert!(
            field.is_nullable(),
            "Expected concat result to be nullable when any input is nullable"
        );
        Ok(())
    }

    /// `Utf8View` wins over every other string type, regardless of position.
    #[test]
    fn test_spark_concat_return_field_prefers_utf8_view() -> Result<()> {
        let field = infer_return_field(&[
            (DataType::Utf8, false),
            (DataType::Utf8View, false),
            (DataType::LargeUtf8, false),
        ])?;
        assert_eq!(field.data_type(), &DataType::Utf8View);
        Ok(())
    }

    /// Without any `Utf8View`, a single `LargeUtf8` argument promotes the output.
    #[test]
    fn test_spark_concat_return_field_large_utf8() -> Result<()> {
        let field =
            infer_return_field(&[(DataType::Utf8, false), (DataType::LargeUtf8, true)])?;
        assert_eq!(field.data_type(), &DataType::LargeUtf8);
        assert!(field.is_nullable());
        Ok(())
    }

    /// Zero arguments yield a non-nullable `Utf8` field named after the function.
    #[test]
    fn test_spark_concat_return_field_no_args() -> Result<()> {
        let field = infer_return_field(&[])?;
        assert_eq!(field.name(), "concat");
        assert_eq!(field.data_type(), &DataType::Utf8);
        assert!(!field.is_nullable());
        Ok(())
    }
}