// hamelin_datafusion 0.7.8
//
// Translate Hamelin TypedAST to DataFusion LogicalPlans
// Documentation
//! JSON to Variant UDF for DataFusion.
//!
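//! Implements `json_to_variant(json_string)`, registered with DataFusion as
//! `hamelin_json_to_variant`, which parses a JSON string into a native Arrow
//! Variant value. Null inputs and strings that fail to parse as JSON yield a
//! null Variant instead of an error.
//!
//! Adapted from the `datafusion-variant` crate.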

use std::any::Any;
use std::sync::Arc;

use datafusion::arrow::array::{Array, ArrayRef, StringArray, StringViewArray};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{exec_datafusion_err, exec_err, Result, ScalarValue};

use super::{normalize_variant_struct, scalar_to_string, variant_data_type};
use datafusion::logical_expr::{
    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
    TypeSignature, Volatility,
};
use parquet_variant_compute::{VariantArrayBuilder, VariantType};
use parquet_variant_json::JsonToVariant;

/// UDF that converts a JSON string to a Variant value.
///
/// The function is exposed under the name `hamelin_json_to_variant` and takes
/// a single string argument (`Utf8`, `LargeUtf8`, or `Utf8View`). A null input,
/// or a string that cannot be parsed as JSON, produces a null Variant rather
/// than an execution error.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct JsonToVariantUdf {
    /// Accepted argument types and volatility, as reported to DataFusion.
    signature: Signature,
}

impl Default for JsonToVariantUdf {
    /// Builds the UDF with its fixed signature: exactly one argument, which may
    /// be any of the three Arrow string types, and immutable volatility.
    fn default() -> Self {
        let accepted_types = vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
        let type_signature = TypeSignature::Uniform(1, accepted_types);
        let signature = Signature::new(type_signature, Volatility::Immutable);

        Self { signature }
    }
}

impl ScalarUDFImpl for JsonToVariantUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "hamelin_json_to_variant"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(variant_data_type())
    }

    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<Arc<Field>> {
        Ok(Arc::new(
            Field::new(self.name(), variant_data_type(), true).with_extension_type(VariantType),
        ))
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let arg = args
            .args
            .first()
            .ok_or_else(|| exec_datafusion_err!("json_to_variant expects 1 argument"))?;

        match arg {
            ColumnarValue::Scalar(scalar_value) => {
                let json_str = scalar_to_string(scalar_value)?;

                let mut builder = VariantArrayBuilder::new(1);

                match json_str {
                    Some(json_str) => match builder.append_json(json_str.as_str()) {
                        Ok(()) => {}
                        Err(_) => builder.append_null(),
                    },
                    None => builder.append_null(),
                }

                let struct_array = normalize_variant_struct(builder.build().into());
                Ok(ColumnarValue::Scalar(ScalarValue::Struct(Arc::new(
                    struct_array,
                ))))
            }
            ColumnarValue::Array(arr) => match arr.data_type() {
                DataType::Utf8 => Ok(ColumnarValue::Array(from_string_array::<StringArray>(arr)?)),
                DataType::Utf8View => Ok(ColumnarValue::Array(
                    from_string_array::<StringViewArray>(arr)?,
                )),
                DataType::LargeUtf8 => Ok(ColumnarValue::Array(from_string_array::<
                    datafusion::arrow::array::LargeStringArray,
                >(arr)?)),
                _ => exec_err!(
                    "json_to_variant expects string input, got {}",
                    arr.data_type()
                ),
            },
        }
    }
}

/// Converts an Arrow string array of type `T` into a Variant array.
///
/// Each non-null element is parsed as JSON. Null elements, and elements that
/// fail to parse, become null entries, so the output has as many rows as the
/// input.
///
/// # Errors
///
/// Returns an execution error if `arr` is not actually a `T`.
fn from_string_array<T>(arr: &ArrayRef) -> Result<ArrayRef>
where
    T: Array + 'static,
    for<'a> &'a T: IntoIterator<Item = Option<&'a str>>,
{
    let Some(strings) = arr.as_any().downcast_ref::<T>() else {
        return Err(exec_datafusion_err!(
            "Unable to downcast array as expected string type"
        ));
    };

    let mut builder = VariantArrayBuilder::new(strings.len());

    for entry in strings {
        // A value is added only when the element is present and parses.
        match entry {
            Some(text) if builder.append_json(text).is_ok() => {}
            _ => builder.append_null(),
        }
    }

    let normalized = normalize_variant_struct(builder.build().into());
    let output: ArrayRef = Arc::new(normalized);
    Ok(output)
}

/// Creates a [`ScalarUDF`] wrapping a default [`JsonToVariantUdf`].
pub fn json_to_variant_udf() -> ScalarUDF {
    let implementation = JsonToVariantUdf::default();
    ScalarUDF::new_from_impl(implementation)
}

#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::array::Array;
    use parquet_variant::Variant;
    use parquet_variant_compute::VariantArray;

    /// Runs the UDF on one nullable `Utf8` argument and returns its output,
    /// panicking if planning or invocation fails.
    fn invoke_udf(input: ColumnarValue, number_rows: usize) -> ColumnarValue {
        let udf = JsonToVariantUdf::default();
        let arg_field = Arc::new(Field::new("input", DataType::Utf8, true));
        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[arg_field.clone()],
                scalar_arguments: &[],
            })
            .unwrap();

        let args = ScalarFunctionArgs {
            args: vec![input],
            return_field,
            arg_fields: vec![arg_field],
            number_rows,
            config_options: Default::default(),
        };

        udf.invoke_with_args(args).unwrap()
    }

    #[test]
    fn test_json_to_variant_scalar_object() {
        let json_input = ScalarValue::Utf8(Some(r#"{"name": "test", "value": 42}"#.to_string()));

        let result = invoke_udf(ColumnarValue::Scalar(json_input), 1);

        let ColumnarValue::Scalar(ScalarValue::Struct(struct_arr)) = result else {
            panic!("Expected scalar struct result");
        };

        let variant_array = VariantArray::try_new(struct_arr.as_ref()).unwrap();

        // Verify it's an object with the expected fields
        match variant_array.value(0) {
            Variant::Object(obj) => {
                assert!(obj.get("name").is_some());
                assert!(obj.get("value").is_some());
            }
            other => panic!("Expected Variant::Object, got {:?}", other),
        }
    }

    #[test]
    fn test_json_to_variant_scalar_null() {
        let json_input = ScalarValue::Utf8(None);

        let result = invoke_udf(ColumnarValue::Scalar(json_input), 1);

        let ColumnarValue::Scalar(ScalarValue::Struct(struct_arr)) = result else {
            panic!("Expected scalar struct result");
        };

        assert!(struct_arr.is_null(0), "Expected null variant");
    }

    #[test]
    fn test_json_to_variant_array() {
        let json_array: ArrayRef = Arc::new(StringArray::from(vec![
            Some(r#"{"a": 1}"#),
            Some(r#"{"a": 2}"#),
            None,
        ]));

        let result = invoke_udf(ColumnarValue::Array(json_array), 3);

        let ColumnarValue::Array(arr) = result else {
            panic!("Expected array result");
        };

        assert_eq!(arr.len(), 3);
        assert!(!arr.is_null(0));
        assert!(!arr.is_null(1));
        assert!(arr.is_null(2));
    }
}