// hamelin_datafusion 0.7.5
//
// Translate Hamelin TypedAST to DataFusion LogicalPlans
// Documentation
//! Variant to JSON string UDF for DataFusion.
//!
//! Converts a Variant to its JSON string representation.

use std::any::Any;
use std::fmt::Write;
use std::sync::Arc;

use datafusion::arrow::array::{ArrayRef, StringArray, StringBuilder};
use datafusion::arrow::datatypes::DataType;
use datafusion::common::{exec_err, Result, ScalarValue};
use datafusion::logical_expr::{
    ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use parquet_variant::Variant;
use parquet_variant_compute::VariantArray;

/// Scalar UDF implementation that renders a Variant value as a JSON string.
///
/// Registered under the name `hamelin_to_json_string`; use
/// [`variant_to_json_udf`] to obtain a ready-to-register [`ScalarUDF`].
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct VariantToJsonUdf {
    /// Signature reported to DataFusion (see the `Default` impl for its value).
    signature: Signature,
}

impl Default for VariantToJsonUdf {
    /// Build the UDF with a signature accepting exactly one argument of any
    /// type, declared as immutable.
    fn default() -> Self {
        let signature = Signature::any(1, Volatility::Immutable);
        Self { signature }
    }
}

impl ScalarUDFImpl for VariantToJsonUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn name(&self) -> &str {
        "hamelin_to_json_string"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        Ok(DataType::Utf8)
    }

    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let args = args.args;
        if args.len() != 1 {
            return exec_err!(
                "hamelin_to_json_string expects exactly 1 argument, got {}",
                args.len()
            );
        }

        match &args[0] {
            ColumnarValue::Scalar(scalar) => {
                if scalar.is_null() {
                    return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
                }
                let array = scalar.to_array_of_size(1)?;
                let variant_array = VariantArray::try_new(&array).map_err(|e| {
                    datafusion::error::DataFusionError::Execution(format!(
                        "hamelin_to_json_string expects a Variant, got error: {e}"
                    ))
                })?;
                let mut json = String::new();
                write_variant(&mut json, &variant_array.value(0))
                    .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(json))))
            }
            ColumnarValue::Array(array) => {
                let variant_array = VariantArray::try_new(array).map_err(|e| {
                    datafusion::error::DataFusionError::Execution(format!(
                        "hamelin_to_json_string expects a Variant, got error: {e}"
                    ))
                })?;

                let mut builder = StringBuilder::new();
                let mut scratch = String::new();
                for idx in 0..variant_array.len() {
                    if !variant_array.is_valid(idx) {
                        builder.append_null();
                    } else {
                        scratch.clear();
                        write_variant(&mut scratch, &variant_array.value(idx)).map_err(|e| {
                            datafusion::error::DataFusionError::Execution(e.to_string())
                        })?;
                        builder.append_value(&scratch);
                    }
                }

                let result: StringArray = builder.finish();
                Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
            }
        }
    }
}

/// Create the `hamelin_to_json_string` scalar UDF, wrapping a default
/// [`VariantToJsonUdf`].
pub fn variant_to_json_udf() -> ScalarUDF {
    let implementation = VariantToJsonUdf::default();
    ScalarUDF::new_from_impl(implementation)
}

/// Write a Variant's JSON representation into an existing buffer, avoiding intermediate
/// allocations for nested structures.
///
/// Mapping used:
/// - null and booleans become the JSON literals `null`, `true`, `false`;
/// - integers, floats and decimals become JSON numbers (non-finite floats become `null`);
/// - strings become escaped JSON strings;
/// - binary becomes a quoted lowercase hex string;
/// - dates, times, timestamps and UUIDs become quoted strings (timezone-aware
///   timestamps use RFC 3339, the others use their `Display` form);
/// - lists and objects recurse, emitting `[...]` and `{...}` without whitespace.
///
/// # Errors
///
/// Propagates any `std::fmt::Error` raised while formatting into `buf`.
fn write_variant(buf: &mut String, variant: &Variant) -> std::fmt::Result {
    match variant {
        Variant::Null => buf.push_str("null"),
        Variant::BooleanTrue => buf.push_str("true"),
        Variant::BooleanFalse => buf.push_str("false"),
        Variant::Int8(n) => write!(buf, "{n}")?,
        Variant::Int16(n) => write!(buf, "{n}")?,
        Variant::Int32(n) => write!(buf, "{n}")?,
        Variant::Int64(n) => write!(buf, "{n}")?,
        Variant::Float(f) => {
            // Format the f32 itself rather than widening to f64 first: the
            // widened value prints the binary approximation (0.1f32 would be
            // emitted as 0.10000000149011612) instead of the shortest text
            // that round-trips to the stored f32.
            if f.is_finite() {
                write!(buf, "{f}")?;
            } else {
                // JSON has no NaN/Infinity literals.
                buf.push_str("null");
            }
        }
        Variant::Double(f) => write_float(buf, *f)?,
        Variant::Decimal4(d) => write_decimal(buf, d.integer() as i128, d.scale())?,
        Variant::Decimal8(d) => write_decimal(buf, d.integer() as i128, d.scale())?,
        Variant::Decimal16(d) => write_decimal(buf, d.integer(), d.scale())?,
        Variant::String(s) => write_json_string(buf, s)?,
        Variant::ShortString(s) => write_json_string(buf, s.as_ref())?,
        Variant::Binary(b) => {
            buf.push('"');
            write_hex_bytes(buf, b.as_ref());
            buf.push('"');
        }
        Variant::Date(d) => write!(buf, "\"{d}\"")?,
        Variant::TimestampMicros(ts) => write!(buf, "\"{}\"", ts.to_rfc3339())?,
        Variant::TimestampNtzMicros(ts) => write!(buf, "\"{ts}\"")?,
        Variant::TimestampNanos(ts) => write!(buf, "\"{}\"", ts.to_rfc3339())?,
        Variant::TimestampNtzNanos(ts) => write!(buf, "\"{ts}\"")?,
        Variant::Time(t) => write!(buf, "\"{t}\"")?,
        Variant::Uuid(u) => write!(buf, "\"{u}\"")?,
        Variant::List(list) => {
            buf.push('[');
            for (i, v) in list.iter().enumerate() {
                if i > 0 {
                    buf.push(',');
                }
                write_variant(buf, &v)?;
            }
            buf.push(']');
        }
        Variant::Object(obj) => {
            buf.push('{');
            for (i, (k, v)) in obj.iter().enumerate() {
                if i > 0 {
                    buf.push(',');
                }
                write_json_string(buf, k)?;
                buf.push(':');
                write_variant(buf, &v)?;
            }
            buf.push('}');
        }
    }
    Ok(())
}

/// Append the decimal `integer * 10^-scale` to `buf`, always emitting exactly
/// `scale` fractional digits so trailing zeros survive (matching Trino's JSON
/// serialization).
///
/// A `scale` of zero writes the integer as-is with no decimal point.
fn write_decimal(buf: &mut String, integer: i128, scale: u8) -> std::fmt::Result {
    match scale {
        0 => write!(buf, "{integer}"),
        digits => {
            let pow10 = 10_i128.pow(u32::from(digits)) as u128;
            // Work on the magnitude so that values in (-1, 0) keep their sign
            // and i128::MIN does not overflow on negation.
            let magnitude = integer.unsigned_abs();
            if integer < 0 {
                buf.push('-');
            }
            let whole = magnitude / pow10;
            let fraction = magnitude % pow10;
            write!(buf, "{whole}.{fraction:0>pad$}", pad = usize::from(digits))
        }
    }
}

/// Append the lowercase hexadecimal encoding of `bytes` to `buf`, two
/// characters per byte with no separator or prefix.
#[inline]
fn write_hex_bytes(buf: &mut String, bytes: &[u8]) {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    buf.reserve(bytes.len().saturating_mul(2));
    buf.extend(bytes.iter().flat_map(|&byte| {
        [
            DIGITS[usize::from(byte >> 4)],
            DIGITS[usize::from(byte & 0x0f)],
        ]
    }));
}

/// Append `f` to `buf` as a JSON number, or as `null` when it is NaN or
/// infinite (JSON cannot represent those values).
fn write_float(buf: &mut String, f: f64) -> std::fmt::Result {
    if f.is_finite() {
        return write!(buf, "{f}");
    }
    buf.push_str("null");
    Ok(())
}

/// Append `s` to `buf` as a double-quoted JSON string literal.
///
/// Exactly the characters JSON requires to be escaped are escaped: `"`, `\`,
/// and the control characters U+0000..=U+001F (`\n`, `\r`, `\t` use their
/// short forms, the rest use `\u00XX`). Every other character, including DEL
/// and non-ASCII text, is copied verbatim, so a given character is always
/// serialized the same way regardless of what else the string contains.
fn write_json_string(buf: &mut String, s: &str) -> std::fmt::Result {
    buf.push('"');

    // Fast path: no escaping needed for common ASCII/UTF-8 strings. The bytes
    // tested here are all ASCII, so they can never match inside a multi-byte
    // UTF-8 sequence.
    let needs_escape = s.bytes().any(|b| b == b'"' || b == b'\\' || b < 0x20);
    if !needs_escape {
        buf.push_str(s);
        buf.push('"');
        return Ok(());
    }

    for c in s.chars() {
        match c {
            '"' => buf.push_str("\\\""),
            '\\' => buf.push_str("\\\\"),
            '\n' => buf.push_str("\\n"),
            '\r' => buf.push_str("\\r"),
            '\t' => buf.push_str("\\t"),
            // Same set the fast path treats as "needs escaping"; using
            // `char::is_control` here would additionally escape DEL and the
            // C1 controls, but only when some other character forced the
            // slow path.
            c if (c as u32) < 0x20 => write!(buf, "\\u{:04x}", c as u32)?,
            c => buf.push(c),
        }
    }
    buf.push('"');
    Ok(())
}