use std::sync::Arc;
use arrow::datatypes::{DataType, Field, FieldRef};
use datafusion_common::types::{NativeType, logical_string};
use datafusion_common::{Result, internal_err};
use datafusion_expr::{
Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
Signature, TypeSignatureClass, Volatility,
};
use datafusion_functions::string::ascii::ascii;
use datafusion_functions::utils::make_scalar_function;
use std::any::Any;
/// Spark-compatible `ascii` scalar function.
///
/// Takes a single argument that must be a logical string; numeric arguments
/// are implicitly coerced to a string first. The result type is `Int32`.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct SparkAscii {
    /// Coercible signature describing the accepted argument types.
    signature: Signature,
}
impl Default for SparkAscii {
fn default() -> Self {
Self::new()
}
}
impl SparkAscii {
    /// Creates the UDF with an immutable, coercible single-argument signature.
    ///
    /// The argument must be a logical string. Values of any numeric type are
    /// also accepted and implicitly coerced to `NativeType::String` before
    /// evaluation (presumably to match Spark's implicit casting — confirm
    /// against Spark's `ascii` semantics).
    pub fn new() -> Self {
        let desired = TypeSignatureClass::Native(logical_string());
        let coercible_from = vec![TypeSignatureClass::Numeric];
        let only_arg = Coercion::new_implicit(desired, coercible_from, NativeType::String);

        let signature = Signature::coercible(vec![only_arg], Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for SparkAscii {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// SQL name of the function; it is also used as the output field name.
    fn name(&self) -> &str {
        "ascii"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Intentionally unsupported: the output nullability depends on the
    /// argument *fields*, so callers must use `return_field_from_args`.
    ///
    /// # Errors
    ///
    /// Always returns an internal error.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        internal_err!("return_field_from_args should be used instead")
    }

    /// Returns an `Int32` field that is nullable iff any argument field is nullable.
    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
        let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
        // Take the field name from `name()` so the two can never drift apart.
        Ok(Arc::new(Field::new(self.name(), DataType::Int32, nullable)))
    }

    /// Evaluates the arguments with DataFusion's `ascii` kernel.
    ///
    /// `make_scalar_function` adapts the array-based kernel to `ColumnarValue`
    /// inputs. The empty `vec![]` passes no per-argument hints, presumably
    /// leaving the helper's defaults in place (verify against its docs).
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        make_scalar_function(ascii, vec![])(&args.args)
    }
}
#[cfg(test)]
mod tests {
    // `super::*` already re-exposes the parent's imports (including
    // `ReturnFieldArgs`), so no separate import is needed.
    use super::*;

    /// Computes the return field for a single `Utf8` argument with the given nullability.
    fn return_field_for(input_nullable: bool) -> FieldRef {
        let input = Arc::new(Field::new("input", DataType::Utf8, input_nullable));
        SparkAscii::new()
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[input],
                scalar_arguments: &[],
            })
            .unwrap()
    }

    #[test]
    fn test_return_field_nullable_input() {
        let result = return_field_for(true);
        assert_eq!(result.data_type(), &DataType::Int32);
        assert!(
            result.is_nullable(),
            "Output should be nullable when input is nullable"
        );
    }

    #[test]
    fn test_return_field_non_nullable_input() {
        let result = return_field_for(false);
        assert_eq!(result.data_type(), &DataType::Int32);
        assert!(
            !result.is_nullable(),
            "Output should not be nullable when input is not nullable"
        );
    }

    #[test]
    fn test_return_field_named_after_function() {
        assert_eq!(return_field_for(true).name(), "ascii");
    }

    #[test]
    fn test_return_type_is_unsupported() {
        let result = SparkAscii::new().return_type(&[DataType::Utf8]);
        assert!(
            result.is_err(),
            "return_type should direct callers to return_field_from_args"
        );
    }
}