use std::any::Any;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::sync::Arc;
use datafusion::arrow::array::{Array, ArrayRef, AsArray, StructArray};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::{exec_datafusion_err, exec_err, Result, ScalarValue};
use super::{normalize_variant_struct, variant_fields};
use datafusion::logical_expr::{
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
TypeSignature, Volatility,
};
use parquet_variant::{Variant, VariantPath, VariantPathElement};
use parquet_variant_compute::{
variant_get as parquet_variant_get, GetOptions, VariantArray, VariantArrayBuilder, VariantType,
};
/// One step of a parsed `hamelin_variant_get` path.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum PathSegment {
/// Object member lookup by field name, e.g. `foo` in `foo.bar`.
Field(String),
/// List element lookup, e.g. `[2]`. Negative values count back from the end
/// of the list when the path is evaluated (see `resolve_index`).
Index(isize),
}
/// Resolves a possibly negative list index against a list of `len` elements.
///
/// Non-negative indices address elements from the front. Negative indices count
/// back from the end, so `-1` is the last element. Returns `None` whenever the
/// index does not land inside `0..len`.
fn resolve_index(index: isize, len: usize) -> Option<usize> {
    if index >= 0 {
        // Lossless: a non-negative `isize` always fits in `usize`.
        let forward = index.unsigned_abs();
        (forward < len).then_some(forward)
    } else {
        // `unsigned_abs` is well-defined even for `isize::MIN`, and `checked_sub`
        // rejects offsets that reach before the first element. Staying in
        // unsigned arithmetic avoids casting `len` to `isize`, which would wrap
        // for lengths above `isize::MAX`.
        len.checked_sub(index.unsigned_abs())
    }
}
/// Parses a textual variant path into its segments.
///
/// A path is a sequence of field names and bracketed integer indices, for
/// example `users[0].scores[-1]`. After each segment a single `.` is consumed as
/// a separator if present. The empty string parses to an empty segment list.
///
/// # Errors
///
/// Returns an execution error for a leading or trailing dot, an empty field
/// name (e.g. `a..b`), an unclosed bracket, empty brackets, or bracket content
/// that does not parse as an `isize`.
fn parse_path(path: &str) -> Result<Vec<PathSegment>> {
    if path.starts_with('.') {
        return exec_err!("hamelin_variant_get: invalid path '{}': leading dot", path);
    }

    let mut parsed = Vec::new();
    let mut rest = path;
    while !rest.is_empty() {
        if let Some(after_open) = rest.strip_prefix('[') {
            // Index segment: everything up to the first `]` must be an integer.
            let Some((digits, tail)) = after_open.split_once(']') else {
                return exec_err!(
                    "hamelin_variant_get: invalid path '{}': unclosed bracket",
                    path
                );
            };
            if digits.is_empty() {
                return exec_err!(
                    "hamelin_variant_get: invalid path '{}': empty brackets",
                    path
                );
            }
            let Ok(position) = digits.parse::<isize>() else {
                return exec_err!(
                    "hamelin_variant_get: invalid path '{}': non-integer bracket content '{}'",
                    path,
                    digits
                );
            };
            parsed.push(PathSegment::Index(position));
            rest = tail;
        } else {
            // Field segment: runs until the next `.` or `[`, or the end of input.
            let stop = rest
                .find(|c: char| matches!(c, '.' | '['))
                .unwrap_or(rest.len());
            let (name, tail) = rest.split_at(stop);
            if name.is_empty() {
                return exec_err!(
                    "hamelin_variant_get: invalid path '{}': empty field name",
                    path
                );
            }
            parsed.push(PathSegment::Field(name.to_string()));
            rest = tail;
        }

        // A dot directly after a segment is a separator and must be followed by
        // another segment.
        if let Some(after_dot) = rest.strip_prefix('.') {
            if after_dot.is_empty() {
                return exec_err!("hamelin_variant_get: invalid path '{}': trailing dot", path);
            }
            rest = after_dot;
        }
    }
    Ok(parsed)
}
/// Converts parsed segments into a `VariantPath` for the `variant_get` kernel.
///
/// Returns `None` as soon as a negative index is encountered, because the
/// conversion only produces `usize` index elements.
fn to_variant_path(segments: &[PathSegment]) -> Option<VariantPath<'static>> {
    segments
        .iter()
        .map(|segment| match segment {
            PathSegment::Field(name) => Some(VariantPathElement::from(name.clone())),
            PathSegment::Index(idx) => {
                (*idx >= 0).then(|| VariantPathElement::from(*idx as usize))
            }
        })
        .collect()
}
/// Scalar UDF `hamelin_variant_get(variant, path)`: extracts the value found at
/// `path` inside a Variant value.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct HamelinVariantGetUdf {
// Signature reported to DataFusion through `ScalarUDFImpl::signature`.
signature: Signature,
}
impl Default for HamelinVariantGetUdf {
    /// Builds the UDF with an immutable signature that accepts any two arguments.
    fn default() -> Self {
        let signature = Signature::new(TypeSignature::Any(2), Volatility::Immutable);
        Self { signature }
    }
}
impl ScalarUDFImpl for HamelinVariantGetUdf {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Returns the function's name, `hamelin_variant_get`.
    fn name(&self) -> &str {
        "hamelin_variant_get"
    }

    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// The result is always the Variant struct type, whatever the argument types.
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
        let variant_struct = DataType::Struct(variant_fields());
        Ok(variant_struct)
    }

    /// Returns a nullable Variant struct field tagged with the `VariantType`
    /// extension type.
    fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result<Arc<Field>> {
        let storage = DataType::Struct(variant_fields());
        let field = Field::new(self.name(), storage, true).with_extension_type(VariantType);
        Ok(Arc::new(field))
    }

    /// Evaluates the function for every scalar/array combination of the
    /// `(variant, path)` arguments.
    ///
    /// # Errors
    ///
    /// Fails when the argument count is not two, when both arguments are arrays
    /// of different lengths, or when path parsing or extraction fails.
    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
        let [variant_arg, path_arg] = args.args.as_slice() else {
            return exec_err!(
                "hamelin_variant_get expects exactly 2 arguments, got {}",
                args.args.len()
            );
        };

        let output = match (variant_arg, path_arg) {
            (ColumnarValue::Scalar(variant), ColumnarValue::Scalar(raw_path)) => {
                let segments = scalar_to_parsed_path(raw_path)?;
                let extracted = extract_from_scalar_variant(variant, segments.as_deref())?;
                ColumnarValue::Scalar(extracted)
            }
            (ColumnarValue::Array(variants), ColumnarValue::Scalar(raw_path)) => {
                let segments = scalar_to_parsed_path(raw_path)?;
                let extracted = extract_from_variant_array(variants, segments.as_deref())?;
                ColumnarValue::Array(extracted)
            }
            (ColumnarValue::Scalar(variant), ColumnarValue::Array(raw_paths)) => {
                let per_row = array_to_parsed_paths(raw_paths)?;
                let extracted = extract_scalar_with_path_array(variant, &per_row)?;
                ColumnarValue::Array(extracted)
            }
            (ColumnarValue::Array(variants), ColumnarValue::Array(raw_paths)) => {
                // Row-wise evaluation pairs entries by position, so the lengths
                // have to agree before any path is parsed.
                if variants.len() != raw_paths.len() {
                    return exec_err!(
                        "hamelin_variant_get: variant and path arrays must have same length"
                    );
                }
                let per_row = array_to_parsed_paths(raw_paths)?;
                let extracted = extract_from_variant_array_with_paths(variants, &per_row)?;
                ColumnarValue::Array(extracted)
            }
        };
        Ok(output)
    }
}
fn scalar_to_parsed_path(scalar: &ScalarValue) -> Result<Option<Vec<PathSegment>>> {
let raw = match scalar {
ScalarValue::Utf8(s) => s.clone(),
ScalarValue::Utf8View(s) => s.clone(),
ScalarValue::LargeUtf8(s) => s.clone(),
ScalarValue::Null => return Ok(None),
other => {
return exec_err!(
"hamelin_variant_get path must be a string, got {}",
other.data_type()
)
}
};
match raw {
Some(s) => Ok(Some(parse_path(&s)?)),
None => Ok(None),
}
}
/// Parses every entry of a string path array, one result per row.
///
/// Null entries become `None`. Each distinct path string is parsed once and
/// shared between rows through an `Arc`.
///
/// # Errors
///
/// Fails when the array is not `Utf8`, `LargeUtf8` or `Utf8View`, or when any
/// non-null entry is a malformed path.
fn array_to_parsed_paths(array: &ArrayRef) -> Result<Vec<Option<Arc<Vec<PathSegment>>>>> {
    // Keyed by string slices borrowed from `array`, so probing the cache does not
    // allocate an owned `String` for every row.
    let mut cache: HashMap<&str, Arc<Vec<PathSegment>>> = HashMap::new();
    let iter: Box<dyn Iterator<Item = Option<&str>>> = match array.data_type() {
        DataType::Utf8 => Box::new(array.as_string::<i32>().iter()),
        DataType::LargeUtf8 => Box::new(array.as_string::<i64>().iter()),
        DataType::Utf8View => Box::new(array.as_string_view().iter()),
        other => {
            return exec_err!(
                "hamelin_variant_get path must be a string array, got {}",
                other
            )
        }
    };
    iter.map(|opt_s| {
        opt_s
            .map(|s| {
                let parsed = match cache.entry(s) {
                    Entry::Occupied(hit) => Arc::clone(hit.get()),
                    Entry::Vacant(slot) => {
                        // Only a successfully parsed path is cached; a parse
                        // error aborts the whole conversion.
                        let segments = Arc::new(parse_path(s)?);
                        Arc::clone(slot.insert(segments))
                    }
                };
                Ok(parsed)
            })
            .transpose()
    })
    .collect()
}
/// Evaluates a path against a single scalar variant.
///
/// A `Null` scalar or a null struct row yields a null variant scalar. A `None`
/// path is treated like the empty path, i.e. the variant itself is returned.
///
/// # Errors
///
/// Fails when the scalar is neither `Struct` nor `Null`, or when the struct
/// cannot be wrapped in a `VariantArray`.
fn extract_from_scalar_variant(
    variant_scalar: &ScalarValue,
    path: Option<&[PathSegment]>,
) -> Result<ScalarValue> {
    let struct_arr = match variant_scalar {
        ScalarValue::Struct(inner) => inner,
        ScalarValue::Null => return Ok(null_variant_scalar()),
        other => {
            return exec_err!(
                "hamelin_variant_get expects Variant input, got {}",
                other.data_type()
            )
        }
    };
    if struct_arr.is_null(0) {
        return Ok(null_variant_scalar());
    }

    let typed = VariantArray::try_new(struct_arr.as_ref())
        .map_err(|e| exec_datafusion_err!("Failed to create VariantArray: {}", e))?;
    let root = typed.value(0);
    let segments = path.unwrap_or(&[]);
    let found = extract_by_path(&root, segments);
    Ok(variant_to_scalar(found))
}
fn extract_from_variant_array(
variant_array: &ArrayRef,
path: Option<&[PathSegment]>,
) -> Result<ArrayRef> {
let segments = path.unwrap_or(&[]);
if segments.is_empty() {
return Ok(variant_array.clone());
}
if let Some(kernel_path) = to_variant_path(segments) {
let options = GetOptions {
path: kernel_path,
as_type: None,
..Default::default()
};
if let Ok(result) = parquet_variant_get(variant_array, options) {
return Ok(result);
}
}
let struct_arr = variant_array
.as_any()
.downcast_ref::<StructArray>()
.ok_or_else(|| exec_datafusion_err!("hamelin_variant_get expects Variant array input"))?;
let variant_array = VariantArray::try_new(struct_arr)
.map_err(|e| exec_datafusion_err!("Failed to create VariantArray: {}", e))?;
let mut builder = VariantArrayBuilder::new(variant_array.len());
for i in 0..variant_array.len() {
if !variant_array.is_valid(i) {
builder.append_null();
} else {
let variant = variant_array.value(i);
let extracted = extract_by_path(&variant, segments);
append_variant_to_builder(&mut builder, extracted);
}
}
let result = normalize_variant_struct(builder.build().into());
Ok(Arc::new(result) as ArrayRef)
}
/// Evaluates one scalar variant against a per-row list of paths.
///
/// The output has one row per path. A row is null when the variant is null
/// (a `Null` scalar or a null struct row), when the path for that row is
/// `None`, or when the path does not resolve.
///
/// # Errors
///
/// Fails when the scalar is neither `Struct` nor `Null`, or when the struct
/// cannot be wrapped in a `VariantArray`.
fn extract_scalar_with_path_array(
    variant_scalar: &ScalarValue,
    paths: &[Option<Arc<Vec<PathSegment>>>],
) -> Result<ArrayRef> {
    let struct_arr = match variant_scalar {
        ScalarValue::Struct(inner) => inner,
        ScalarValue::Null => {
            let mut nulls = VariantArrayBuilder::new(paths.len());
            paths.iter().for_each(|_| nulls.append_null());
            let normalized = normalize_variant_struct(nulls.build().into());
            return Ok(Arc::new(normalized) as ArrayRef);
        }
        other => {
            return exec_err!(
                "hamelin_variant_get expects Variant input, got {}",
                other.data_type()
            )
        }
    };

    let typed = VariantArray::try_new(struct_arr.as_ref())
        .map_err(|e| exec_datafusion_err!("Failed to create VariantArray: {}", e))?;
    let root = (!struct_arr.is_null(0)).then(|| typed.value(0));

    let mut builder = VariantArrayBuilder::new(paths.len());
    for path in paths {
        let extracted = match (&root, path) {
            (Some(v), Some(segments)) => extract_by_path(v, segments),
            _ => None,
        };
        append_variant_to_builder(&mut builder, extracted);
    }
    let normalized = normalize_variant_struct(builder.build().into());
    Ok(Arc::new(normalized) as ArrayRef)
}
/// Evaluates a variant array against a per-row list of paths.
///
/// Row `i` of the output is the value found at `paths[i]` inside row `i` of
/// `variant_array`. It is null when the variant row is null, when the path is
/// `None`, or when the path does not resolve. The caller must pass `paths` with
/// the same length as `variant_array`.
///
/// # Errors
///
/// Fails when the input is not a `StructArray` or cannot be wrapped in a
/// `VariantArray`.
fn extract_from_variant_array_with_paths(
    variant_array: &ArrayRef,
    paths: &[Option<Arc<Vec<PathSegment>>>],
) -> Result<ArrayRef> {
    // Shortcut: when every row shares the same non-null path, evaluate it as a
    // single path over the whole array. An all-null path column must NOT take
    // this shortcut: `extract_from_variant_array` returns its input unchanged
    // for a missing path, whereas the row loop below yields null for a null
    // path. Taking it would make a row's result depend on the other rows.
    if let Some(Some(shared)) = paths.first() {
        if paths.iter().all(|p| p.as_ref() == Some(shared)) {
            return extract_from_variant_array(variant_array, Some(shared.as_slice()));
        }
    }

    let struct_arr = variant_array
        .as_any()
        .downcast_ref::<StructArray>()
        .ok_or_else(|| exec_datafusion_err!("hamelin_variant_get expects Variant array input"))?;
    let variant_array = VariantArray::try_new(struct_arr)
        .map_err(|e| exec_datafusion_err!("Failed to create VariantArray: {}", e))?;

    let mut builder = VariantArrayBuilder::new(variant_array.len());
    for (i, path) in paths.iter().enumerate() {
        match path {
            Some(segments) if variant_array.is_valid(i) => {
                let variant = variant_array.value(i);
                let extracted = extract_by_path(&variant, segments.as_slice());
                append_variant_to_builder(&mut builder, extracted);
            }
            // Null variant row or null path: the result row is null.
            _ => builder.append_null(),
        }
    }
    let result = normalize_variant_struct(builder.build().into());
    Ok(Arc::new(result) as ArrayRef)
}
/// Walks `segments` through `variant` and returns the value reached.
///
/// Returns `None` when a field segment meets a non-object, an index segment
/// meets a non-list, a field is absent, or an index is out of range. An empty
/// segment list returns a clone of `variant`.
fn extract_by_path<'a>(
    variant: &'a Variant<'a, 'a>,
    segments: &[PathSegment],
) -> Option<Variant<'a, 'a>> {
    segments
        .iter()
        .try_fold(variant.clone(), |node, step| match (step, node) {
            (PathSegment::Field(key), Variant::Object(object)) => object.get(key),
            (PathSegment::Index(offset), Variant::List(items)) => {
                let position = resolve_index(*offset, items.len())?;
                items.get(position)
            }
            // Segment kind does not match the shape of the current value.
            _ => None,
        })
}
/// Wraps an optional variant in a one-row struct scalar; `None` becomes the
/// null variant scalar.
fn variant_to_scalar(variant: Option<Variant>) -> ScalarValue {
    let Some(value) = variant else {
        return null_variant_scalar();
    };
    let mut single = VariantArrayBuilder::new(1);
    single.append_variant(value);
    let normalized = normalize_variant_struct(single.build().into());
    ScalarValue::Struct(Arc::new(normalized))
}
/// Appends `variant` to `builder`, or a null row when it is `None`.
fn append_variant_to_builder(builder: &mut VariantArrayBuilder, variant: Option<Variant>) {
    if let Some(value) = variant {
        builder.append_variant(value);
    } else {
        builder.append_null();
    }
}
/// Builds a struct scalar holding a single null variant row.
fn null_variant_scalar() -> ScalarValue {
    let mut single = VariantArrayBuilder::new(1);
    single.append_null();
    let normalized = normalize_variant_struct(single.build().into());
    ScalarValue::Struct(Arc::new(normalized))
}
/// Creates the `hamelin_variant_get` function as a DataFusion [`ScalarUDF`].
pub fn variant_get_udf() -> ScalarUDF {
    let implementation = HamelinVariantGetUdf::default();
    ScalarUDF::new_from_impl(implementation)
}
#[cfg(test)]
mod tests {
    use super::*;
    use parquet_variant_compute::VariantArrayBuilder;
    use parquet_variant_json::JsonToVariant;
    use rstest::rstest;

    /// Builds a one-row variant scalar from a JSON document.
    fn variant_scalar_from_json(json: &str) -> ScalarValue {
        let mut builder = VariantArrayBuilder::new(1);
        builder.append_json(json).unwrap();
        ScalarValue::Struct(Arc::new(builder.build().into()))
    }

    /// Invokes the UDF with scalar `variant` and `path` arguments.
    fn invoke_variant_get(variant: ScalarValue, path: ScalarValue) -> Result<ColumnarValue> {
        let udf = HamelinVariantGetUdf::default();
        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[],
                scalar_arguments: &[],
            })
            .unwrap();
        let args = ScalarFunctionArgs {
            args: vec![ColumnarValue::Scalar(variant), ColumnarValue::Scalar(path)],
            return_field,
            arg_fields: vec![],
            number_rows: 1,
            config_options: Default::default(),
        };
        udf.invoke_with_args(args)
    }

    /// Runs the UDF on `json` with `path` and reports whether the single result
    /// row is null. Panics when the result is not a one-row struct scalar.
    fn extraction_is_null(json: &str, path: &str) -> bool {
        let variant = variant_scalar_from_json(json);
        let path = ScalarValue::Utf8(Some(path.to_string()));
        match invoke_variant_get(variant, path).unwrap() {
            ColumnarValue::Scalar(ScalarValue::Struct(arr)) => {
                assert_eq!(arr.len(), 1, "scalar result must hold exactly one row");
                arr.is_null(0)
            }
            _ => panic!("Expected scalar struct result"),
        }
    }

    #[rstest]
    #[case("", vec![])]
    #[case("foo", vec![PathSegment::Field("foo".into())])]
    #[case("foo.bar", vec![PathSegment::Field("foo".into()), PathSegment::Field("bar".into())])]
    #[case("[0]", vec![PathSegment::Index(0)])]
    #[case("[-1]", vec![PathSegment::Index(-1)])]
    #[case("arr[2]", vec![PathSegment::Field("arr".into()), PathSegment::Index(2)])]
    #[case(
        "users[0].scores[-1]",
        vec![
            PathSegment::Field("users".into()),
            PathSegment::Index(0),
            PathSegment::Field("scores".into()),
            PathSegment::Index(-1),
        ]
    )]
    #[case(
        "[0].name",
        vec![PathSegment::Index(0), PathSegment::Field("name".into())]
    )]
    fn test_parse_path_valid(#[case] input: &str, #[case] expected: Vec<PathSegment>) {
        assert_eq!(parse_path(input).unwrap(), expected);
    }

    #[rstest]
    #[case("foo[", "unclosed bracket")]
    #[case("foo[]", "empty brackets")]
    #[case("foo[abc]", "non-integer bracket content")]
    #[case("foo..bar", "empty field name")]
    #[case("foo.", "trailing dot")]
    #[case(".foo", "leading dot")]
    #[case("[1].", "trailing dot")]
    fn test_parse_path_invalid(#[case] input: &str, #[case] expected_fragment: &str) {
        let err = parse_path(input).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains(expected_fragment),
            "Error for '{}' should contain '{}', got: {}",
            input,
            expected_fragment,
            msg
        );
    }

    #[rstest]
    #[case(0, 3, Some(0))]
    #[case(2, 3, Some(2))]
    #[case(3, 3, None)]
    #[case(0, 0, None)]
    #[case(-1, 3, Some(2))]
    #[case(-3, 3, Some(0))]
    #[case(-4, 3, None)]
    #[case(-1, 0, None)]
    fn test_resolve_index(
        #[case] index: isize,
        #[case] len: usize,
        #[case] expected: Option<usize>,
    ) {
        assert_eq!(resolve_index(index, len), expected);
    }

    // The tests below assert on the null flag of the result row. Checking only
    // that the result is a struct scalar would pass for hits and misses alike.

    #[test]
    fn test_extract_field() {
        assert!(!extraction_is_null(r#"{"foo": 1}"#, "foo"));
    }

    #[test]
    fn test_extract_nested_field() {
        assert!(!extraction_is_null(r#"{"foo": {"bar": 2}}"#, "foo.bar"));
    }

    #[test]
    fn test_extract_array_index() {
        assert!(!extraction_is_null(r#"[1, 2, 3]"#, "[1]"));
    }

    #[test]
    fn test_extract_field_then_index() {
        assert!(!extraction_is_null(r#"{"arr": [10, 20, 30]}"#, "arr[2]"));
    }

    #[test]
    fn test_extract_index_then_field() {
        assert!(!extraction_is_null(
            r#"[{"name": "alice"}, {"name": "bob"}]"#,
            "[1].name"
        ));
    }

    #[test]
    fn test_extract_complex_path() {
        assert!(!extraction_is_null(
            r#"{"users": [{"name": "alice", "scores": [90, 85]}, {"name": "bob", "scores": [75, 80]}]}"#,
            "users[0].scores[1]"
        ));
    }

    #[test]
    fn test_extract_missing_field() {
        assert!(extraction_is_null(r#"{"foo": 1}"#, "bar"));
    }

    #[test]
    fn test_negative_index_last_element() {
        assert!(!extraction_is_null(r#"[10, 20, 30]"#, "[-1]"));
    }

    #[test]
    fn test_negative_index_on_nested_array() {
        assert!(!extraction_is_null(r#"{"arr": [10, 20, 30]}"#, "arr[-1]"));
    }

    #[test]
    fn test_negative_index_out_of_range_returns_null() {
        assert!(extraction_is_null(r#"[10, 20, 30]"#, "[-4]"));
    }

    #[test]
    fn test_invalid_path_returns_error() {
        let variant = variant_scalar_from_json(r#"{"foo": 1}"#);
        let path = ScalarValue::Utf8(Some("foo[]".to_string()));
        let result = invoke_variant_get(variant, path);
        assert!(result.is_err());
    }

    #[test]
    fn test_unclosed_bracket_returns_error() {
        let variant = variant_scalar_from_json(r#"{"foo": [1]}"#);
        let path = ScalarValue::Utf8(Some("foo[".to_string()));
        let result = invoke_variant_get(variant, path);
        assert!(result.is_err());
    }

    #[test]
    fn test_double_dot_returns_error() {
        let variant = variant_scalar_from_json(r#"{"foo": {"bar": 1}}"#);
        let path = ScalarValue::Utf8(Some("foo..bar".to_string()));
        let result = invoke_variant_get(variant, path);
        assert!(result.is_err());
    }
}