use std::any::Any;
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, NullBufferBuilder, StringBuilder, StructArray};
use arrow::datatypes::{DataType, Field, FieldRef, Fields};
use datafusion_common::cast::as_string_array;
use datafusion_common::{Result, exec_err, internal_err};
use datafusion_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};
/// Scalar UDF registered under the name `json_tuple`.
///
/// Takes a JSON string followed by one or more field names and returns a
/// struct with one nullable `Utf8` column per requested field, named
/// `c0`, `c1`, ... in argument order.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct JsonTuple {
/// Argument signature reported to the planner (variadic over `Utf8`).
signature: Signature,
}
impl Default for JsonTuple {
/// Returns the same value as [`JsonTuple::new`].
fn default() -> Self {
Self::new()
}
}
impl JsonTuple {
    /// Creates the UDF with an immutable, variadic signature whose arguments
    /// are all `Utf8`.
    pub fn new() -> Self {
        let accepted_types = vec![DataType::Utf8];
        let signature = Signature::variadic(accepted_types, Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for JsonTuple {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"json_tuple"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!("return_field_from_args should be used instead")
}
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
if args.arg_fields.len() < 2 {
return exec_err!(
"json_tuple requires at least 2 arguments (json_string, field1), got {}",
args.arg_fields.len()
);
}
let num_fields = args.arg_fields.len() - 1;
let fields: Fields = (0..num_fields)
.map(|i| Field::new(format!("c{i}"), DataType::Utf8, true))
.collect::<Vec<_>>()
.into();
Ok(Arc::new(Field::new(
self.name(),
DataType::Struct(fields),
true,
)))
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs {
args: arg_values,
return_field,
..
} = args;
let arrays = ColumnarValue::values_to_arrays(&arg_values)?;
let result = json_tuple_inner(&arrays, return_field.data_type())?;
Ok(ColumnarValue::Array(result))
}
}
/// Builds the struct array returned by `json_tuple`.
///
/// `args[0]` holds the JSON documents; each following array holds, per row,
/// the key to look up in that row's document. `return_type` must be the
/// `Struct` type whose fields describe the output columns.
///
/// Per row:
/// * a null JSON input, a string that does not parse, or a parsed value that
///   is not a JSON object produces a null struct row with all-null columns;
/// * otherwise each column is null when its key is null, missing, or maps to
///   JSON `null`; a JSON string is emitted as-is; any other value is emitted
///   through its `to_string()` representation.
///
/// # Errors
///
/// Returns an error when `args` is empty, when `return_type` is not a
/// `Struct`, when an argument is not a string array, when the argument arrays
/// differ in length, or when the struct array cannot be assembled.
fn json_tuple_inner(args: &[ArrayRef], return_type: &DataType) -> Result<ArrayRef> {
    // Report a malformed call as an error rather than panicking on `args[0]`
    // or underflowing `args.len() - 1`.
    let Some((json_arg, field_args)) = args.split_first() else {
        return internal_err!(
            "json_tuple expects a JSON argument followed by field names, got no arguments"
        );
    };

    // Validate the declared output type before doing any per-row work.
    let DataType::Struct(struct_fields) = return_type else {
        return internal_err!(
            "json_tuple requires a Struct return type, got {:?}",
            return_type
        );
    };

    let num_rows = json_arg.len();
    let json_array = as_string_array(json_arg)?;
    let field_arrays = field_args
        .iter()
        .map(|arg| as_string_array(arg))
        .collect::<Result<Vec<_>>>()?;

    // Rows are addressed by the same index in every array, so a shorter
    // field-name array would otherwise panic inside the loop below.
    if field_arrays.iter().any(|names| names.len() != num_rows) {
        return internal_err!(
            "json_tuple arguments must all have the same length ({num_rows} rows)"
        );
    }

    let mut builders: Vec<StringBuilder> = field_arrays
        .iter()
        .map(|_| StringBuilder::new())
        .collect();
    let mut null_buffer = NullBufferBuilder::new(num_rows);

    for row_idx in 0..num_rows {
        // Only a non-null input that parses to a JSON object yields a valid row.
        let object = if json_array.is_null(row_idx) {
            None
        } else {
            match serde_json::from_str::<serde_json::Value>(json_array.value(row_idx)) {
                Ok(serde_json::Value::Object(map)) => Some(map),
                _ => None,
            }
        };

        let Some(map) = object else {
            // Child columns still need one slot per row to stay aligned.
            for builder in &mut builders {
                builder.append_null();
            }
            null_buffer.append_null();
            continue;
        };

        null_buffer.append_non_null();
        for (builder, names) in builders.iter_mut().zip(&field_arrays) {
            if names.is_null(row_idx) {
                builder.append_null();
                continue;
            }
            match map.get(names.value(row_idx)) {
                Some(serde_json::Value::String(s)) => builder.append_value(s),
                Some(serde_json::Value::Null) | None => builder.append_null(),
                Some(other) => builder.append_value(other.to_string()),
            }
        }
    }

    let columns: Vec<ArrayRef> = builders
        .into_iter()
        .map(|mut builder| Arc::new(builder.finish()) as ArrayRef)
        .collect();
    let struct_array =
        StructArray::try_new(struct_fields.clone(), columns, null_buffer.finish())?;
    Ok(Arc::new(struct_array))
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::ReturnFieldArgs;

    /// Builds one non-nullable `Utf8` argument field per name.
    fn utf8_arg_fields(names: &[&str]) -> Vec<FieldRef> {
        names
            .iter()
            .map(|name| Arc::new(Field::new(*name, DataType::Utf8, false)))
            .collect()
    }

    /// Three arguments (JSON + two keys) must yield a struct with columns `c0` and `c1`.
    #[test]
    fn test_return_field_shape() {
        let arg_fields = utf8_arg_fields(&["json", "f1", "f2"]);
        let returned = JsonTuple::new()
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[None, None, None],
            })
            .unwrap();

        let inner = match returned.data_type() {
            DataType::Struct(inner) => inner,
            other => panic!("Expected Struct, got {other:?}"),
        };
        assert_eq!(inner.len(), 2);
        assert_eq!(inner[0].name(), "c0");
        assert_eq!(inner[1].name(), "c1");
        assert_eq!(inner[0].data_type(), &DataType::Utf8);
        assert!(inner[0].is_nullable());
    }

    /// A lone JSON argument with no field names must be rejected.
    #[test]
    fn test_too_few_args() {
        let arg_fields = utf8_arg_fields(&["json"]);
        let outcome = JsonTuple::new().return_field_from_args(ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[None],
        });

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("at least 2 arguments"));
    }
}