use std::any::Any;
use std::sync::Arc;
use datafusion::arrow::array::{
Array, ArrayRef, AsArray, GenericListArray, ListArray, NullArray, StructArray,
};
use datafusion::arrow::buffer::{NullBuffer, OffsetBuffer};
use datafusion::arrow::compute::cast;
use datafusion::arrow::datatypes::{DataType, Field, Fields};
use datafusion::common::{Result, ScalarValue};
use datafusion::logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
};
use datafusion_common::DataFusionError;
use serde::{Deserialize, Serialize};
use super::from_variant::FromVariantUdf;
use super::normalize_variant_struct;
use super::to_variant::cast_array_to_variant;
/// A serializable plan describing how an input array is converted to a
/// target type. Interpreted at execution time by `apply_cast`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CastDescriptor {
/// Source already matches the target type; pass through unchanged.
Identity,
/// Replace the column with an all-null array of the target type.
NullToType,
/// Arrow's built-in `cast` kernel handles the conversion directly.
ArrowCast,
/// Cast each element of a List/LargeList using the inner descriptor.
ArrayElementCast(Box<CastDescriptor>),
/// Rebuild a struct field-by-field: (field name, target type, per-field cast).
/// Fields using `NullToType` may be absent from the source struct.
StructExpansion(Vec<(String, DataType, CastDescriptor)>),
/// Convert the array to the variant representation.
ToVariant,
/// Convert from the variant representation to the given target type.
FromVariant(DataType),
/// Cast the `begin`/`end` fields of a range struct using the inner descriptor.
RangeElementCast(Box<CastDescriptor>),
}
impl CastDescriptor {
    /// Returns `true` when this descriptor can be realized entirely by
    /// Arrow's built-in `cast` kernel, so no custom conversion is needed.
    /// Element-wise wrappers are arrow-native exactly when their inner
    /// descriptor is.
    pub fn is_arrow_native(&self) -> bool {
        match self {
            CastDescriptor::Identity | CastDescriptor::ArrowCast => true,
            CastDescriptor::ArrayElementCast(inner)
            | CastDescriptor::RangeElementCast(inner) => inner.is_arrow_native(),
            CastDescriptor::NullToType
            | CastDescriptor::StructExpansion(_)
            | CastDescriptor::ToVariant
            | CastDescriptor::FromVariant(_) => false,
        }
    }
}
/// Scalar UDF that casts its single argument to `target_type` by executing
/// the `cast_descriptor` plan (see `apply_cast`).
#[derive(Debug)]
pub struct ArrayCastUdf {
// Fixed one-argument, immutable signature (built in `new`).
signature: Signature,
// The data type the input is cast to.
target_type: DataType,
// The plan describing how the cast is performed.
cast_descriptor: CastDescriptor,
}
// Hash over the (constant) UDF name and the target type. Consistent with
// the manual PartialEq below: both ignore `cast_descriptor`.
impl std::hash::Hash for ArrayCastUdf {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name().hash(state);
self.target_type.hash(state);
}
}
// NOTE(review): equality deliberately compares only `target_type`, so two
// UDFs with the same target but different descriptors compare equal —
// presumably the descriptor is derived from the source/target types;
// confirm with the planner that builds these.
impl PartialEq for ArrayCastUdf {
fn eq(&self, other: &Self) -> bool {
self.target_type == other.target_type
}
}
impl Eq for ArrayCastUdf {}
impl ArrayCastUdf {
    /// Builds a cast UDF that converts its single argument to `target_type`
    /// by executing `cast_descriptor`. Accepts any one argument type.
    pub fn new(target_type: DataType, cast_descriptor: CastDescriptor) -> Self {
        let signature = Signature::new(TypeSignature::Any(1), Volatility::Immutable);
        Self {
            signature,
            target_type,
            cast_descriptor,
        }
    }

    /// The data type this UDF casts its input to.
    pub fn target_type(&self) -> &DataType {
        &self.target_type
    }

    /// The plan describing how the cast is performed.
    pub fn cast_descriptor(&self) -> &CastDescriptor {
        &self.cast_descriptor
    }
}
impl ScalarUDFImpl for ArrayCastUdf {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"hamelin_array_cast"
}
fn signature(&self) -> &Signature {
&self.signature
}
// The return type is fixed at construction, independent of the argument type.
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(self.target_type.clone())
}
// Always reports a nullable field (e.g. NullToType produces all-null output).
fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<Arc<Field>> {
Ok(Arc::new(Field::new(
self.name(),
self.target_type.clone(),
true,
)))
}
// Executes the cast plan. Scalars are widened to a 1-row array, cast, then
// collapsed back to a scalar so both columnar forms round-trip.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
if args.args.len() != 1 {
return Err(DataFusionError::Execution(format!(
"hamelin_array_cast expects 1 argument, got {}",
args.args.len()
)));
}
match &args.args[0] {
ColumnarValue::Scalar(scalar) => {
let array = scalar.to_array_of_size(1)?;
let result = apply_cast(&array, &self.target_type, &self.cast_descriptor)?;
let scalar = ScalarValue::try_from_array(&result, 0)?;
Ok(ColumnarValue::Scalar(scalar))
}
ColumnarValue::Array(array) => {
let result = apply_cast(array, &self.target_type, &self.cast_descriptor)?;
Ok(ColumnarValue::Array(result))
}
}
}
}
/// Convenience wrapper: packages an [`ArrayCastUdf`] as a DataFusion
/// [`ScalarUDF`] ready for registration.
pub fn array_cast_udf(target_type: DataType, cast_descriptor: CastDescriptor) -> ScalarUDF {
    let udf_impl = ArrayCastUdf::new(target_type, cast_descriptor);
    ScalarUDF::new_from_impl(udf_impl)
}
fn apply_cast(
array: &ArrayRef,
target_type: &DataType,
descriptor: &CastDescriptor,
) -> Result<ArrayRef> {
if descriptor.is_arrow_native() {
return Ok(cast(array.as_ref(), target_type)?);
}
match descriptor {
CastDescriptor::Identity => Ok(Arc::clone(array)),
CastDescriptor::NullToType => {
new_null_array(target_type, array.len())
}
CastDescriptor::ArrowCast => Ok(cast(array.as_ref(), target_type)?),
CastDescriptor::ArrayElementCast(inner_cast) => {
apply_array_element_cast(array, target_type, inner_cast)
}
CastDescriptor::StructExpansion(field_casts) => {
apply_struct_expansion(array, target_type, field_casts)
}
CastDescriptor::ToVariant => {
let variant_array = cast_array_to_variant(array.as_ref())?;
let struct_array = normalize_variant_struct(variant_array.into());
Ok(Arc::new(struct_array) as ArrayRef)
}
CastDescriptor::FromVariant(inner_target_type) => {
let udf_impl = FromVariantUdf::new(inner_target_type.clone());
let arg_field = Arc::new(Field::new("input", array.data_type().clone(), true));
let return_field = Arc::new(Field::new("output", inner_target_type.clone(), true));
let args = datafusion::logical_expr::ScalarFunctionArgs {
args: vec![datafusion::logical_expr::ColumnarValue::Array(Arc::clone(
array,
))],
return_field,
arg_fields: vec![arg_field],
number_rows: array.len(),
config_options: Default::default(),
};
match udf_impl.invoke_with_args(args)? {
datafusion::logical_expr::ColumnarValue::Array(arr) => Ok(arr),
datafusion::logical_expr::ColumnarValue::Scalar(scalar) => {
scalar.to_array_of_size(array.len())
}
}
}
CastDescriptor::RangeElementCast(inner_cast) => {
apply_range_element_cast(array, target_type, inner_cast)
}
}
}
/// Casts every element of a List/LargeList array with `inner_cast`,
/// preserving the list's offsets and null buffer.
fn apply_array_element_cast(
    array: &ArrayRef,
    target_type: &DataType,
    inner_cast: &CastDescriptor,
) -> Result<ArrayRef> {
    // The target must itself be a list type; its child type drives the
    // element cast.
    let target_element_type = match target_type {
        DataType::List(field) | DataType::LargeList(field) => field.data_type(),
        other => {
            return Err(DataFusionError::Execution(format!(
                "ArrayElementCast target type must be List, got {:?}",
                other
            )))
        }
    };
    match array.data_type() {
        DataType::List(_) => {
            apply_list_element_cast(array.as_list::<i32>(), target_element_type, inner_cast)
        }
        DataType::LargeList(_) => {
            apply_large_list_element_cast(array.as_list::<i64>(), target_element_type, inner_cast)
        }
        other => Err(DataFusionError::Execution(format!(
            "ArrayElementCast source must be List or LargeList, got {:?}",
            other
        ))),
    }
}
/// Rebuilds a `ListArray` (i32 offsets) with its values cast via `inner_cast`.
///
/// NOTE(review): the element field is recreated as nullable with the fixed
/// name "item"; if the target list field uses a different name the resulting
/// data type will not match it exactly — confirm downstream comparisons only
/// look at element types.
fn apply_list_element_cast(
    list_array: &ListArray,
    target_element_type: &DataType,
    inner_cast: &CastDescriptor,
) -> Result<ArrayRef> {
    let cast_values = apply_cast(list_array.values(), target_element_type, inner_cast)?;
    let element_field = Arc::new(Field::new("item", target_element_type.clone(), true));
    let offsets = list_array.offsets().clone();
    let nulls = list_array.nulls().cloned();
    Ok(Arc::new(ListArray::try_new(
        element_field,
        offsets,
        cast_values,
        nulls,
    )?))
}
/// Rebuilds a `LargeList` (i64 offsets) with its values cast via `inner_cast`.
///
/// NOTE(review): like `apply_list_element_cast`, the element field is
/// recreated as nullable with the fixed name "item" — verify that target
/// field names are not significant to callers.
fn apply_large_list_element_cast(
    list_array: &GenericListArray<i64>,
    target_element_type: &DataType,
    inner_cast: &CastDescriptor,
) -> Result<ArrayRef> {
    let cast_values = apply_cast(list_array.values(), target_element_type, inner_cast)?;
    let element_field = Arc::new(Field::new("item", target_element_type.clone(), true));
    let offsets = list_array.offsets().clone();
    let nulls = list_array.nulls().cloned();
    Ok(Arc::new(GenericListArray::<i64>::try_new(
        element_field,
        offsets,
        cast_values,
        nulls,
    )?))
}
/// Rebuilds a struct array against a new field list, casting or null-filling
/// each target field according to `field_casts`.
///
/// Fields marked `NullToType` need not exist in the source struct; every
/// other descriptor requires the named field to be present. The top-level
/// struct null buffer is carried over unchanged.
fn apply_struct_expansion(
    array: &ArrayRef,
    _target_type: &DataType,
    field_casts: &[(String, DataType, CastDescriptor)],
) -> Result<ArrayRef> {
    let struct_array = array
        .as_any()
        .downcast_ref::<StructArray>()
        .ok_or_else(|| {
            DataFusionError::Execution(format!(
                "StructExpansion source must be Struct, got {:?}",
                array.data_type()
            ))
        })?;
    let len = struct_array.len();
    let mut new_columns: Vec<ArrayRef> = Vec::with_capacity(field_casts.len());
    let mut new_fields: Vec<Arc<Field>> = Vec::with_capacity(field_casts.len());
    for (field_name, field_type, cast_desc) in field_casts {
        let column = match cast_desc {
            // Field absent from the source: synthesize an all-null column.
            CastDescriptor::NullToType => new_null_array(field_type, len)?,
            // Every other descriptor (including Identity, which apply_cast
            // resolves to a plain Arrow cast) operates on the source column.
            _ => {
                let col = struct_array.column_by_name(field_name).ok_or_else(|| {
                    DataFusionError::Execution(format!(
                        "Field '{}' not found in source struct",
                        field_name
                    ))
                })?;
                apply_cast(col, field_type, cast_desc)?
            }
        };
        new_fields.push(Arc::new(Field::new(field_name, field_type.clone(), true)));
        new_columns.push(column);
    }
    let result = StructArray::try_new(
        Fields::from(new_fields),
        new_columns,
        struct_array.nulls().cloned(),
    )?;
    Ok(Arc::new(result))
}
/// Casts the `begin` and `end` fields of a range value (represented as a
/// struct with those two fields) to the target element type.
///
/// NOTE(review): the rebuilt struct contains only "begin" and "end", both
/// forced nullable, and both cast to the target's "begin" type — this
/// assumes the target's "end" field has the same type; confirm with the
/// planner that builds RangeElementCast descriptors.
fn apply_range_element_cast(
array: &ArrayRef,
target_type: &DataType,
inner_cast: &CastDescriptor,
) -> Result<ArrayRef> {
let struct_array = array
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| {
DataFusionError::Execution(format!(
"RangeElementCast source must be Struct (range), got {:?}",
array.data_type()
))
})?;
let target_struct = match target_type {
DataType::Struct(fields) => fields,
_ => {
return Err(DataFusionError::Execution(format!(
"RangeElementCast target type must be Struct (range), got {:?}",
target_type
)))
}
};
// The target element type is taken from the target's "begin" field.
let target_element_type = target_struct
.iter()
.find(|f| f.name() == "begin")
.map(|f| f.data_type())
.ok_or_else(|| {
DataFusionError::Execution("Range struct missing 'begin' field".to_string())
})?;
let begin_col = struct_array.column_by_name("begin").ok_or_else(|| {
DataFusionError::Execution("Range struct missing 'begin' field".to_string())
})?;
let end_col = struct_array.column_by_name("end").ok_or_else(|| {
DataFusionError::Execution("Range struct missing 'end' field".to_string())
})?;
// Cast both endpoints with the same inner descriptor.
let cast_begin = apply_cast(begin_col, target_element_type, inner_cast)?;
let cast_end = apply_cast(end_col, target_element_type, inner_cast)?;
let new_fields = Fields::from(vec![
Field::new("begin", target_element_type.clone(), true),
Field::new("end", target_element_type.clone(), true),
]);
// Top-level null buffer of the source range struct is preserved.
let result = StructArray::try_new(
new_fields,
vec![cast_begin, cast_end],
struct_array.nulls().cloned(),
)?;
Ok(Arc::new(result))
}
/// Builds an all-null array of `data_type` with `len` rows.
///
/// Nested types (List, LargeList, Struct) are constructed recursively so
/// child arrays carry the correct types; every other type falls back to
/// broadcasting a null `ScalarValue`.
fn new_null_array(data_type: &DataType, len: usize) -> Result<ArrayRef> {
    match data_type {
        DataType::Null => Ok(Arc::new(NullArray::new(len))),
        DataType::List(field) => {
            // Zeroed offsets + a null buffer: every slot is a null, empty list.
            let child = new_null_array(field.data_type(), 0)?;
            let list = ListArray::try_new(
                Arc::clone(field),
                OffsetBuffer::new_zeroed(len),
                child,
                Some(NullBuffer::new_null(len)),
            )?;
            Ok(Arc::new(list))
        }
        DataType::LargeList(field) => {
            let child = new_null_array(field.data_type(), 0)?;
            let list = GenericListArray::<i64>::try_new(
                Arc::clone(field),
                OffsetBuffer::<i64>::new_zeroed(len),
                child,
                Some(NullBuffer::new_null(len)),
            )?;
            Ok(Arc::new(list))
        }
        DataType::Struct(fields) => {
            // Recurse into each child field, then mark every row null.
            let columns = fields
                .iter()
                .map(|f| new_null_array(f.data_type(), len))
                .collect::<Result<Vec<_>>>()?;
            let array =
                StructArray::try_new(fields.clone(), columns, Some(NullBuffer::new_null(len)))?;
            Ok(Arc::new(array))
        }
        other => Ok(ScalarValue::try_from(other)?.to_array_of_size(len)?),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use datafusion::arrow::array::Int32Array;
// Expanding {a: Int32} to {a: Int32, b: Utf8}: 'a' passes through and the
// new 'b' column is synthesized as all-null via NullToType.
#[test]
fn test_struct_expansion_add_null_field() {
let a_values = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
let source = StructArray::try_from(vec![("a", a_values)]).unwrap();
let source_ref: ArrayRef = Arc::new(source);
let target_type = DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
]));
let field_casts = vec![
("a".to_string(), DataType::Int32, CastDescriptor::Identity),
("b".to_string(), DataType::Utf8, CastDescriptor::NullToType),
];
let result = apply_struct_expansion(&source_ref, &target_type, &field_casts).unwrap();
let struct_result = result.as_any().downcast_ref::<StructArray>().unwrap();
assert_eq!(struct_result.num_columns(), 2);
assert_eq!(struct_result.len(), 3);
let a_col = struct_result.column(0);
let a_arr = a_col.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(a_arr.value(0), 1);
assert_eq!(a_arr.value(1), 2);
assert_eq!(a_arr.value(2), 3);
// The synthesized 'b' column must be null on every row.
let b_col = struct_result.column(1);
assert!(b_col.is_null(0));
assert!(b_col.is_null(1));
assert!(b_col.is_null(2));
}
// Nested case: a List<Struct{a}> cast via ArrayElementCast(StructExpansion)
// must expand the inner struct elements while keeping list boundaries.
#[test]
fn test_array_of_struct_expansion() {
let a_values = Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as ArrayRef;
let struct_fields = Fields::from(vec![Field::new("a", DataType::Int32, true)]);
let inner_struct =
StructArray::try_new(struct_fields.clone(), vec![a_values], None).unwrap();
// Two lists of two elements each.
let offsets = OffsetBuffer::from_lengths([2, 2]);
let field = Arc::new(Field::new("item", DataType::Struct(struct_fields), true));
let source = ListArray::try_new(field, offsets, Arc::new(inner_struct), None).unwrap();
let source_ref: ArrayRef = Arc::new(source);
let target_struct_type = DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Utf8, true),
]));
let target_type = DataType::List(Arc::new(Field::new("item", target_struct_type, true)));
let inner_cast = CastDescriptor::StructExpansion(vec![
("a".to_string(), DataType::Int32, CastDescriptor::Identity),
("b".to_string(), DataType::Utf8, CastDescriptor::NullToType),
]);
let cast_desc = CastDescriptor::ArrayElementCast(Box::new(inner_cast));
let result = apply_cast(&source_ref, &target_type, &cast_desc).unwrap();
let list_result = result.as_any().downcast_ref::<ListArray>().unwrap();
assert_eq!(list_result.len(), 2);
// The first list still has two elements, each with the expanded 2-field struct.
let first = list_result.value(0);
let first_struct = first.as_any().downcast_ref::<StructArray>().unwrap();
assert_eq!(first_struct.len(), 2);
assert_eq!(first_struct.num_columns(), 2);
}
// Source struct only has 'b'; target adds 'a' as null and keeps 'b'.
// Exercises the NullToType path for a field missing from the source.
#[test]
fn test_struct_expansion_different_source_field() {
let b_values = Arc::new(Int32Array::from(vec![2])) as ArrayRef;
let source = StructArray::try_from(vec![("b", b_values)]).unwrap();
let source_ref: ArrayRef = Arc::new(source);
let target_type = DataType::Struct(Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
]));
let field_casts = vec![
("a".to_string(), DataType::Int32, CastDescriptor::NullToType),
("b".to_string(), DataType::Int32, CastDescriptor::Identity),
];
let result = apply_struct_expansion(&source_ref, &target_type, &field_casts).unwrap();
let struct_result = result.as_any().downcast_ref::<StructArray>().unwrap();
assert_eq!(struct_result.num_columns(), 2);
assert_eq!(struct_result.len(), 1);
let a_col = struct_result.column(0);
assert!(a_col.is_null(0), "Field 'a' should be null");
let b_col = struct_result.column(1);
let b_arr = b_col.as_any().downcast_ref::<Int32Array>().unwrap();
assert_eq!(b_arr.value(0), 2, "Field 'b' should have value 2");
}
}