use std::any::Any;
use std::sync::Arc;
use datafusion::arrow::array::{Array, ArrayRef, StringArray, StringViewArray};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{exec_datafusion_err, exec_err, Result, ScalarValue};
use super::{normalize_variant_struct, scalar_to_string, variant_data_type};
use datafusion::logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
};
use parquet_variant_compute::{VariantArrayBuilder, VariantType};
use parquet_variant_json::JsonToVariant;
/// Scalar UDF implementation registered under the name `hamelin_json_to_variant`.
///
/// Takes a single string argument containing JSON text and produces a Variant
/// value (see the `ScalarUDFImpl` implementation for the conversion rules).
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct JsonToVariantUdf {
/// Argument signature reported to the planner; populated by `Default::default`.
signature: Signature,
}
impl Default for JsonToVariantUdf {
    /// Builds the UDF with an immutable, single-argument signature that accepts
    /// any of the three Arrow string encodings (`Utf8`, `LargeUtf8`, `Utf8View`).
    fn default() -> Self {
        let accepted_types = vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
        let type_signature = TypeSignature::Uniform(1, accepted_types);
        let signature = Signature::new(type_signature, Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for JsonToVariantUdf {
    /// Exposes the implementation as `Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Name under which the function is registered.
    fn name(&self) -> &str {
        "hamelin_json_to_variant"
    }

    /// Signature built in `Default::default`.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result type is always the variant data type, regardless of which
    /// string encoding was supplied.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(variant_data_type())
    }

    /// Describes the output column: nullable, named after the function, and
    /// tagged with the `VariantType` extension type.
    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<Arc<Field>> {
        let field =
            Field::new(self.name(), variant_data_type(), true).with_extension_type(VariantType);
        Ok(Arc::new(field))
    }

    /// Converts the first argument from JSON text into a Variant value.
    ///
    /// * Scalar input produces a one-row struct scalar.
    /// * Array input produces an array with one Variant per input row.
    ///
    /// In both cases a null input, or text that `append_json` rejects, yields a
    /// null entry instead of an error.
    ///
    /// # Errors
    ///
    /// Returns an execution error when no argument is supplied, when
    /// `scalar_to_string` fails, or when an array argument is not one of the
    /// supported string types.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let Some(input) = args.args.first() else {
            return Err(exec_datafusion_err!("json_to_variant expects 1 argument"));
        };

        match input {
            ColumnarValue::Scalar(scalar) => {
                let maybe_json = scalar_to_string(scalar)?;
                let mut builder = VariantArrayBuilder::new(1);

                // `parsed` is true only when a non-null string was appended successfully.
                let parsed = match maybe_json {
                    Some(text) => builder.append_json(text.as_str()).is_ok(),
                    None => false,
                };
                if !parsed {
                    builder.append_null();
                }

                let normalized = normalize_variant_struct(builder.build().into());
                Ok(ColumnarValue::Scalar(ScalarValue::Struct(Arc::new(
                    normalized,
                ))))
            }
            ColumnarValue::Array(strings) => {
                // Dispatch on the concrete string encoding; anything else is rejected.
                let converted = match strings.data_type() {
                    DataType::Utf8 => from_string_array::<StringArray>(strings)?,
                    DataType::Utf8View => from_string_array::<StringViewArray>(strings)?,
                    DataType::LargeUtf8 => {
                        from_string_array::<datafusion::arrow::array::LargeStringArray>(strings)?
                    }
                    other => {
                        return exec_err!("json_to_variant expects string input, got {}", other)
                    }
                };
                Ok(ColumnarValue::Array(converted))
            }
        }
    }
}
/// Converts every entry of a string array of concrete type `T` into a Variant row.
///
/// Null entries, and entries that `append_json` rejects, become null rows, so
/// the output always has the same length as the input.
///
/// # Errors
///
/// Returns an execution error if `arr` cannot be downcast to `T`.
fn from_string_array<T>(arr: &ArrayRef) -> Result<ArrayRef>
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
    let Some(strings) = arr.as_any().downcast_ref::<T>() else {
        return Err(exec_datafusion_err!(
            "Unable to downcast array as expected string type"
        ));
    };

    let mut builder = VariantArrayBuilder::new(strings.len());
    for entry in strings {
        // Only a present string that parses cleanly contributes a value;
        // everything else is recorded as null.
        match entry.map(|json| builder.append_json(json)) {
            Some(Ok(())) => {}
            Some(Err(_)) | None => builder.append_null(),
        }
    }

    let normalized = normalize_variant_struct(builder.build().into());
    Ok(Arc::new(normalized) as ArrayRef)
}
/// Creates a `ScalarUDF` wrapping a default-constructed [`JsonToVariantUdf`].
pub fn json_to_variant_udf() -> ScalarUDF {
    let implementation = JsonToVariantUdf::default();
    ScalarUDF::new_from_impl(implementation)
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::Array;
    use parquet_variant::Variant;
    use parquet_variant_compute::VariantArray;

    /// Invokes a default `JsonToVariantUdf` on a single nullable `Utf8`
    /// argument and returns the result, panicking on any error.
    fn invoke_on_utf8(input: ColumnarValue, number_rows: usize) -> ColumnarValue {
        let udf = JsonToVariantUdf::default();
        let arg_field = Arc::new(Field::new("input", DataType::Utf8, true));
        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[arg_field.clone()],
                scalar_arguments: &[],
            })
            .unwrap();
        let args = ScalarFunctionArgs {
            args: vec![input],
            return_field,
            arg_fields: vec![arg_field],
            number_rows,
            config_options: Default::default(),
        };
        udf.invoke_with_args(args).unwrap()
    }

    /// A scalar JSON object is parsed into a `Variant::Object` exposing its keys.
    #[test]
    fn test_json_to_variant_scalar_object() {
        let json_input = ScalarValue::Utf8(Some(r#"{"name": "test", "value": 42}"#.to_string()));
        let result = invoke_on_utf8(ColumnarValue::Scalar(json_input), 1);

        let ColumnarValue::Scalar(ScalarValue::Struct(struct_arr)) = result else {
            panic!("Expected scalar struct result");
        };
        let variant_array = VariantArray::try_new(struct_arr.as_ref()).unwrap();
        match variant_array.value(0) {
            Variant::Object(obj) => {
                assert!(obj.get("name").is_some());
                assert!(obj.get("value").is_some());
            }
            other => panic!("Expected Variant::Object, got {:?}", other),
        }
    }

    /// A null scalar input yields a null variant.
    #[test]
    fn test_json_to_variant_scalar_null() {
        let result = invoke_on_utf8(ColumnarValue::Scalar(ScalarValue::Utf8(None)), 1);

        let ColumnarValue::Scalar(ScalarValue::Struct(struct_arr)) = result else {
            panic!("Expected scalar struct result");
        };
        assert!(struct_arr.is_null(0), "Expected null variant");
    }

    /// Array input is converted row by row, preserving nulls.
    #[test]
    fn test_json_to_variant_array() {
        let json_array: ArrayRef = Arc::new(StringArray::from(vec![
            Some(r#"{"a": 1}"#),
            Some(r#"{"a": 2}"#),
            None,
        ]));
        let result = invoke_on_utf8(ColumnarValue::Array(json_array), 3);

        let ColumnarValue::Array(arr) = result else {
            panic!("Expected array result");
        };
        assert_eq!(arr.len(), 3);
        assert!(!arr.is_null(0));
        assert!(!arr.is_null(1));
        assert!(arr.is_null(2));
    }
}