use std::any::Any;
use std::sync::Arc;
use datafusion::arrow::array::{
Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, Int16Array,
Int32Array, Int64Array, Int8Array, LargeListArray, LargeStringArray, ListArray, MapArray,
StringArray, StringViewArray, StructArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{Result, ScalarValue};
use super::{normalize_variant_struct, variant_data_type};
use datafusion::logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
};
use datafusion_common::DataFusionError;
use parquet_variant::{
ObjectFieldBuilder, Variant, VariantBuilder, VariantBuilderExt, VariantDecimal16,
};
use parquet_variant_compute::{VariantArray, VariantArrayBuilder, VariantType};
/// Scalar UDF implementation whose `name()` is `hamelin_cast_to_variant`.
///
/// The only state it carries is the [`Signature`] handed back from
/// `ScalarUDFImpl::signature`.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct CastToVariantUdf {
    /// Argument signature reported for this function.
    signature: Signature,
}
impl Default for CastToVariantUdf {
    /// Creates the UDF with a signature that accepts exactly one argument of any type and
    /// is declared `Volatility::Immutable`.
    fn default() -> Self {
        let signature = Signature::new(TypeSignature::Any(1), Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for CastToVariantUdf {
    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name of the function: `hamelin_cast_to_variant`.
    fn name(&self) -> &str {
        "hamelin_cast_to_variant"
    }

    /// Signature stored on the struct.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result type is always `variant_data_type()`, regardless of the argument type.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(variant_data_type())
    }

    /// Describes the output as a nullable field named after the function, typed with
    /// `variant_data_type()` and tagged with the `VariantType` extension type.
    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<Arc<Field>> {
        let field = Field::new(self.name(), variant_data_type(), true);
        Ok(Arc::new(field.with_extension_type(VariantType)))
    }

    /// Converts the single argument to variant form.
    ///
    /// An array argument is converted row by row and returned as an array. A scalar
    /// argument is expanded to a one-row array, converted, and returned as a
    /// `ScalarValue::Struct`.
    ///
    /// # Errors
    /// Returns `DataFusionError::Execution` when the argument count is not exactly one, and
    /// propagates any error from `cast_array_to_variant` or `to_array_of_size`.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        match args.args.as_slice() {
            [ColumnarValue::Array(array)] => {
                let variants = cast_array_to_variant(array.as_ref())?;
                let normalized = normalize_variant_struct(variants.into());
                Ok(ColumnarValue::Array(Arc::new(normalized) as ArrayRef))
            }
            [ColumnarValue::Scalar(scalar)] => {
                let single_row = scalar.to_array_of_size(1)?;
                let variants = cast_array_to_variant(single_row.as_ref())?;
                let normalized = normalize_variant_struct(variants.into());
                Ok(ColumnarValue::Scalar(ScalarValue::Struct(Arc::new(
                    normalized,
                ))))
            }
            // Any slice that is not exactly one element long lands here.
            other => Err(DataFusionError::Execution(format!(
                "hamelin_cast_to_variant expects 1 argument, got {}",
                other.len()
            ))),
        }
    }
}
/// Builds a [`ScalarUDF`] wrapping a default-constructed [`CastToVariantUdf`].
pub fn cast_to_variant_udf() -> ScalarUDF {
    let implementation = CastToVariantUdf::default();
    ScalarUDF::new_from_impl(implementation)
}
pub fn cast_array_to_variant(array: &dyn Array) -> Result<VariantArray> {
let mut builder = VariantArrayBuilder::new(array.len());
for i in 0..array.len() {
if array.is_null(i) {
builder.append_null();
} else {
append_array_value_to_variant(&mut builder, array, i)?;
}
}
Ok(builder.build())
}
/// Appends row `idx` of `array` to `builder`.
///
/// A null row is appended as a null. Any other row is encoded with a fresh
/// [`VariantBuilder`], re-read through `Variant::try_new`, and appended as a value.
///
/// # Errors
/// Returns `DataFusionError::Execution` when the encoded bytes cannot be read back as a
/// variant, and propagates errors from `build_variant_value`.
fn append_array_value_to_variant(
    builder: &mut VariantArrayBuilder,
    array: &dyn Array,
    idx: usize,
) -> Result<()> {
    if !array.is_null(idx) {
        let mut value_builder = VariantBuilder::new();
        build_variant_value(&mut value_builder, array, idx)?;
        let (metadata, value) = value_builder.finish();
        match Variant::try_new(&metadata, &value) {
            Ok(variant) => builder.append_value(variant),
            Err(e) => {
                return Err(DataFusionError::Execution(format!(
                    "Failed to create variant: {e}"
                )))
            }
        }
    } else {
        builder.append_null();
    }
    Ok(())
}
/// Writes the element at `idx` of `array` into `builder` as one variant value.
///
/// Conversion rules, by Arrow data type:
/// - a null element, or any element of a `Null`-typed array, becomes `Variant::Null`;
/// - booleans, signed integers, floats and the three string types are appended directly;
/// - `UInt8`/`UInt16`/`UInt32` are widened to `i16`/`i32`/`i64`;
/// - `UInt64` is appended as `i64` when it fits, otherwise as a scale-0 `VariantDecimal16`;
/// - `Decimal128` becomes a `VariantDecimal16` carrying the array's scale;
/// - `List`/`LargeList` become variant lists, `Map` becomes a variant object;
/// - a `Struct` accepted by `VariantArray::try_new` has its variant value appended as-is;
///   any other struct becomes a variant object built field by field.
///
/// # Errors
/// Returns `DataFusionError::Execution` for an unhandled data type, a failed downcast, a
/// `Decimal128` scale that does not fit in `u8`, or a value rejected by `VariantDecimal16`.
/// Errors from nested list/struct/map conversion are propagated.
fn build_variant_value(
    builder: &mut impl VariantBuilderExt,
    array: &dyn Array,
    idx: usize,
) -> Result<()> {
    /// Downcasts `array` to the concrete array type `T`, failing with `message`.
    fn typed<'a, T: 'static>(array: &'a dyn Array, message: &str) -> Result<&'a T> {
        array
            .as_any()
            .downcast_ref::<T>()
            .ok_or_else(|| DataFusionError::Execution(message.to_string()))
    }

    if array.is_null(idx) {
        builder.append_value(Variant::Null);
        return Ok(());
    }

    match array.data_type() {
        DataType::Null => builder.append_value(Variant::Null),
        DataType::Boolean => {
            let flag = typed::<BooleanArray>(array, "Expected BooleanArray")?.value(idx);
            builder.append_value(flag);
        }
        DataType::Int8 => {
            let number = typed::<Int8Array>(array, "Expected Int8Array")?.value(idx);
            builder.append_value(number);
        }
        DataType::Int16 => {
            let number = typed::<Int16Array>(array, "Expected Int16Array")?.value(idx);
            builder.append_value(number);
        }
        DataType::Int32 => {
            let number = typed::<Int32Array>(array, "Expected Int32Array")?.value(idx);
            builder.append_value(number);
        }
        DataType::Int64 => {
            let number = typed::<Int64Array>(array, "Expected Int64Array")?.value(idx);
            builder.append_value(number);
        }
        // Unsigned values are widened to the next larger signed type before appending.
        DataType::UInt8 => {
            let number = typed::<UInt8Array>(array, "Expected UInt8Array")?.value(idx);
            builder.append_value(number as i16);
        }
        DataType::UInt16 => {
            let number = typed::<UInt16Array>(array, "Expected UInt16Array")?.value(idx);
            builder.append_value(number as i32);
        }
        DataType::UInt32 => {
            let number = typed::<UInt32Array>(array, "Expected UInt32Array")?.value(idx);
            builder.append_value(number as i64);
        }
        DataType::UInt64 => {
            let raw = typed::<UInt64Array>(array, "Expected UInt64Array")?.value(idx);
            match i64::try_from(raw) {
                Ok(signed) => builder.append_value(signed),
                // Values above `i64::MAX` are stored as a scale-0 decimal instead.
                Err(_) => {
                    let decimal = VariantDecimal16::try_new(i128::from(raw), 0).map_err(|e| {
                        DataFusionError::Execution(format!(
                            "Failed to create variant decimal for u64: {e}"
                        ))
                    })?;
                    builder.append_value(decimal);
                }
            }
        }
        DataType::Float32 => {
            let number = typed::<Float32Array>(array, "Expected Float32Array")?.value(idx);
            builder.append_value(number);
        }
        DataType::Float64 => {
            let number = typed::<Float64Array>(array, "Expected Float64Array")?.value(idx);
            builder.append_value(number);
        }
        DataType::Decimal128(_, scale) => {
            let unscaled = typed::<Decimal128Array>(array, "Expected Decimal128Array")?.value(idx);
            // The scale must fit in `u8`; a negative Arrow scale fails this conversion.
            let variant_scale = u8::try_from(*scale).map_err(|_| {
                DataFusionError::Execution(format!("Negative decimal scale not supported: {scale}"))
            })?;
            let decimal = VariantDecimal16::try_new(unscaled, variant_scale).map_err(|e| {
                DataFusionError::Execution(format!("Failed to create variant decimal: {e}"))
            })?;
            builder.append_value(decimal);
        }
        DataType::Utf8 => {
            let text = typed::<StringArray>(array, "Expected StringArray")?.value(idx);
            builder.append_value(text);
        }
        DataType::LargeUtf8 => {
            let text = typed::<LargeStringArray>(array, "Expected LargeStringArray")?.value(idx);
            builder.append_value(text);
        }
        DataType::Utf8View => {
            let text = typed::<StringViewArray>(array, "Expected StringViewArray")?.value(idx);
            builder.append_value(text);
        }
        DataType::List(_) => {
            let items = typed::<ListArray>(array, "Expected ListArray")?.value(idx);
            build_list_variant(builder, items.as_ref())?;
        }
        DataType::LargeList(_) => {
            let items = typed::<LargeListArray>(array, "Expected LargeListArray")?.value(idx);
            build_list_variant(builder, items.as_ref())?;
        }
        DataType::Struct(_) => {
            let structs = typed::<StructArray>(array, "Expected StructArray")?;
            // A struct accepted by `VariantArray::try_new` has its variant value appended
            // as-is; anything else is converted field by field into a variant object.
            match VariantArray::try_new(structs) {
                Ok(variants) => builder.append_value(variants.value(idx)),
                Err(_) => build_struct_variant(builder, structs, idx)?,
            }
        }
        DataType::Map(_, _) => {
            let maps = typed::<MapArray>(array, "Expected MapArray")?;
            build_map_variant(builder, maps, idx)?;
        }
        other => {
            return Err(DataFusionError::Execution(format!(
                "Unsupported type for variant conversion: {other}"
            )))
        }
    }
    Ok(())
}
/// Writes all elements of `values` into `builder` as one variant list.
///
/// Each element is converted with `build_variant_value`. On the first conversion error
/// the function returns that error without calling `finish` on the list builder.
fn build_list_variant(builder: &mut impl VariantBuilderExt, values: &dyn Array) -> Result<()> {
    let mut elements = builder.new_list();
    (0..values.len())
        .try_for_each(|position| build_variant_value(&mut elements, values, position))?;
    elements.finish();
    Ok(())
}
/// Writes row `idx` of the struct array `arr` into `builder` as a variant object.
///
/// Every struct field contributes one object entry keyed by the field name. A null child
/// value is stored as `Variant::Null`; any other value is converted with
/// `build_variant_value`.
fn build_struct_variant(
    builder: &mut impl VariantBuilderExt,
    arr: &StructArray,
    idx: usize,
) -> Result<()> {
    let mut object = builder.new_object();
    // Fields and child columns are stored in the same order, so pair them up directly.
    for (field, column) in arr.fields().iter().zip(arr.columns()) {
        let name = field.name();
        if !column.is_null(idx) {
            let mut entry = ObjectFieldBuilder::new(name, &mut object);
            build_variant_value(&mut entry, column.as_ref(), idx)?;
        } else {
            object.insert(name, Variant::Null);
        }
    }
    object.finish();
    Ok(())
}
/// Writes the map stored at row `idx` of `arr` into `builder` as a variant object.
///
/// Column 0 of the map's entries supplies the keys, stringified by `get_string_key`, and
/// column 1 supplies the values. A null value is stored as `Variant::Null`.
///
/// # Errors
/// Returns `DataFusionError::Execution` when the entries are not a `StructArray`, and
/// propagates errors from key stringification or value conversion.
fn build_map_variant(
    builder: &mut impl VariantBuilderExt,
    arr: &MapArray,
    idx: usize,
) -> Result<()> {
    let entries = arr.value(idx);
    let Some(pairs) = entries.as_any().downcast_ref::<StructArray>() else {
        return Err(DataFusionError::Execution(
            "Map entries are not StructArray".to_string(),
        ));
    };
    let (key_column, value_column) = (pairs.column(0), pairs.column(1));

    let mut object = builder.new_object();
    for row in 0..entries.len() {
        let key = get_string_key(key_column.as_ref(), row)?;
        match value_column.is_null(row) {
            true => object.insert(&key, Variant::Null),
            false => {
                let mut entry = ObjectFieldBuilder::new(&key, &mut object);
                build_variant_value(&mut entry, value_column.as_ref(), row)?;
            }
        }
    }
    object.finish();
    Ok(())
}
/// Returns the map key at `idx` of `array` rendered as an owned `String`.
///
/// String keys (`Utf8`, `LargeUtf8`, `Utf8View`) are copied; `Int32` and `Int64` keys are
/// formatted with `to_string`.
///
/// # Errors
/// Returns `DataFusionError::Execution` when the key is null, when the key type is not
/// one of the types listed above, or when the downcast to the concrete array fails.
fn get_string_key(array: &dyn Array, idx: usize) -> Result<String> {
    /// Downcasts `array` to the concrete array type `T`, failing with `message`.
    fn typed<'a, T: 'static>(array: &'a dyn Array, message: &str) -> Result<&'a T> {
        array
            .as_any()
            .downcast_ref::<T>()
            .ok_or_else(|| DataFusionError::Execution(message.to_string()))
    }

    if array.is_null(idx) {
        return Err(DataFusionError::Execution(
            "Null map keys are not supported in variant conversion".to_string(),
        ));
    }

    let key = match array.data_type() {
        DataType::Utf8 => typed::<StringArray>(array, "Expected StringArray for map key")?
            .value(idx)
            .to_string(),
        DataType::LargeUtf8 => {
            typed::<LargeStringArray>(array, "Expected LargeStringArray for map key")?
                .value(idx)
                .to_string()
        }
        DataType::Utf8View => {
            typed::<StringViewArray>(array, "Expected StringViewArray for map key")?
                .value(idx)
                .to_string()
        }
        DataType::Int32 => typed::<Int32Array>(array, "Expected Int32Array for map key")?
            .value(idx)
            .to_string(),
        DataType::Int64 => typed::<Int64Array>(array, "Expected Int64Array for map key")?
            .value(idx)
            .to_string(),
        other => {
            return Err(DataFusionError::Execution(format!(
                "Unsupported map key type: {other}"
            )))
        }
    };
    Ok(key)
}