1use std::sync::Arc;
7
8use arrow::{
9 array::{Array, ArrayRef, AsArray, StringViewArray, StructArray},
10 compute::concat,
11};
12use arrow_schema::{DataType, Field, FieldRef, Fields, extension::ExtensionType};
13use datafusion::{
14 common::{exec_datafusion_err, exec_err},
15 error::{DataFusionError, Result},
16 logical_expr::{
17 ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
18 TypeSignature, Volatility,
19 },
20 scalar::ScalarValue,
21};
22use parquet::variant::VariantPath;
23use parquet::variant::{GetOptions, VariantArray, VariantType, variant_get};
24use parquet_variant_json::VariantToJson;
25
26pub fn try_field_as_variant_array(field: &Field) -> Result<()> {
27 assert!(
28 matches!(field.extension_type(), VariantType),
29 "field does not have extension type VariantType"
30 );
31
32 let variant_type = VariantType;
33 variant_type.supports_data_type(field.data_type())?;
34
35 Ok(())
36}
37
38pub fn _try_field_as_binary(field: &Field) -> Result<()> {
39 match field.data_type() {
40 DataType::Binary | DataType::BinaryView | DataType::LargeBinary => {}
41 unsupported => return exec_err!("expected binary field, got {unsupported} field"),
42 }
43
44 Ok(())
45}
46
47pub fn try_parse_string_columnar(array: &Arc<dyn Array>) -> Result<Vec<Option<&str>>> {
48 if let Some(string_array) = array.as_string_opt::<i32>() {
49 return Ok(string_array.into_iter().collect::<Vec<_>>());
50 }
51
52 if let Some(string_view_array) = array.as_string_view_opt() {
53 return Ok(string_view_array.into_iter().collect::<Vec<_>>());
54 }
55
56 if let Some(large_string_array) = array.as_string_opt::<i64>() {
57 return Ok(large_string_array.into_iter().collect::<Vec<_>>());
58 }
59
60 Err(exec_datafusion_err!("expected string array"))
61}
62
63pub fn try_parse_string_scalar(scalar: &ScalarValue) -> Result<Option<&String>> {
64 let b = match scalar {
65 ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => s,
66 unsupported => {
67 return exec_err!(
68 "expected binary scalar value, got data type: {}",
69 unsupported.data_type()
70 );
71 }
72 };
73
74 Ok(b.as_ref())
75}
76
77fn parse_type_hint(spec: &str) -> Result<DataType> {
78 if let Ok(data_type) = spec.parse::<DataType>() {
79 Ok(data_type)
80 } else {
81 exec_err!("invalid type hint: {spec}")
82 }
83}
84
85fn type_hint_from_scalar(field_name: &str, scalar: &ScalarValue) -> Result<FieldRef> {
86 let type_name = match scalar {
87 ScalarValue::Utf8(Some(value))
88 | ScalarValue::Utf8View(Some(value))
89 | ScalarValue::LargeUtf8(Some(value)) => value.as_str(),
90 other => {
91 return exec_err!(
92 "type hint must be a non-null UTF8 literal, got {}",
93 other.data_type()
94 );
95 }
96 };
97
98 let data_type = parse_type_hint(type_name)?;
99 Ok(Arc::new(Field::new(field_name, data_type, true)))
100}
101
102fn type_hint_from_value(field_name: &str, arg: &ColumnarValue) -> Result<FieldRef> {
103 match arg {
104 ColumnarValue::Scalar(value) => type_hint_from_scalar(field_name, value),
105 ColumnarValue::Array(_) => {
106 exec_err!("type hint argument must be a scalar UTF8 literal")
107 }
108 }
109}
110
111fn build_get_options<'a>(path: VariantPath<'a>, as_type: &Option<FieldRef>) -> GetOptions<'a> {
112 match as_type {
113 Some(field) => GetOptions::new_with_path(path).with_as_type(Some(field.clone())),
114 None => GetOptions::new_with_path(path),
115 }
116}
117
/// Scalar UDF exposed to DataFusion under the name `variant_get`.
///
/// Holds only the function signature; the `Default` implementation declares that the
/// function takes either two or three arguments of any type.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct VariantGetUdf {
    /// Signature reported to DataFusion through `ScalarUDFImpl::signature`.
    signature: Signature,
}
123
124impl Default for VariantGetUdf {
125 fn default() -> Self {
126 Self {
127 signature: Signature::new(
128 TypeSignature::OneOf(vec![TypeSignature::Any(2), TypeSignature::Any(3)]),
129 Volatility::Immutable,
130 ),
131 }
132 }
133}
134
135impl ScalarUDFImpl for VariantGetUdf {
136 fn as_any(&self) -> &dyn std::any::Any {
137 self
138 }
139
140 fn name(&self) -> &str {
141 "variant_get"
142 }
143
144 fn signature(&self) -> &Signature {
145 &self.signature
146 }
147
148 fn return_type(&self, _arg_types: &[arrow_schema::DataType]) -> Result<arrow_schema::DataType> {
149 Err(DataFusionError::Internal(
150 "implemented return_field_from_args instead".into(),
151 ))
152 }
153
154 fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<Arc<Field>> {
155 if let Some(maybe_scalar) = args.scalar_arguments.get(2) {
156 let scalar = maybe_scalar.ok_or_else(|| {
157 exec_datafusion_err!("type hint argument to variant_get must be a literal")
158 })?;
159 return type_hint_from_scalar(self.name(), scalar);
160 }
161
162 let data_type = DataType::Struct(Fields::from(vec![
163 Field::new("metadata", DataType::BinaryView, false),
164 Field::new("value", DataType::BinaryView, true),
165 ]));
166
167 Ok(Arc::new(
168 Field::new(self.name(), data_type, true).with_extension_type(VariantType),
169 ))
170 }
171
172 fn invoke_with_args(
173 &self,
174 args: datafusion::logical_expr::ScalarFunctionArgs,
175 ) -> Result<ColumnarValue> {
176 let (variant_arg, variant_path, type_arg) = match args.args.as_slice() {
177 [variant_arg, variant_path] => (variant_arg, variant_path, None),
178 [variant_arg, variant_path, type_arg] => (variant_arg, variant_path, Some(type_arg)),
179 _ => return exec_err!("expected 2 or 3 arguments"),
180 };
181
182 let variant_field = args
183 .arg_fields
184 .first()
185 .ok_or_else(|| exec_datafusion_err!("expected argument field"))?;
186
187 try_field_as_variant_array(variant_field.as_ref())?;
188
189 let type_field = type_arg
190 .map(|arg| type_hint_from_value(self.name(), arg))
191 .transpose()?;
192
193 let out = match (variant_arg, variant_path) {
194 (ColumnarValue::Array(variant_array), ColumnarValue::Scalar(variant_path)) => {
195 let variant_path = try_parse_string_scalar(variant_path)?
196 .map(|s| s.as_str())
197 .unwrap_or_default();
198
199 let res = variant_get(
200 variant_array,
201 build_get_options(VariantPath::from(variant_path), &type_field),
202 )?;
203
204 ColumnarValue::Array(res)
205 }
206 (ColumnarValue::Scalar(scalar_variant), ColumnarValue::Scalar(variant_path)) => {
207 let ScalarValue::Struct(variant_array) = scalar_variant else {
208 return exec_err!("expected struct array");
209 };
210
211 let variant_array = Arc::clone(variant_array) as ArrayRef;
212
213 let variant_path = try_parse_string_scalar(variant_path)?
214 .map(|s| s.as_str())
215 .unwrap_or_default();
216
217 let res = variant_get(
218 &variant_array,
219 build_get_options(VariantPath::from(variant_path), &type_field),
220 )?;
221
222 let scalar = ScalarValue::try_from_array(res.as_ref(), 0)?;
223 ColumnarValue::Scalar(scalar)
224 }
225 (ColumnarValue::Array(variant_array), ColumnarValue::Array(variant_paths)) => {
226 if variant_array.len() != variant_paths.len() {
227 return exec_err!(
228 "expected variant_array and variant paths to be of same length"
229 );
230 }
231
232 let variant_paths = try_parse_string_columnar(variant_paths)?;
233 let variant_array = VariantArray::try_new(variant_array.as_ref())?;
234
235 let mut out = Vec::with_capacity(variant_array.len());
236
237 for (i, path) in variant_paths.iter().enumerate() {
238 let v = variant_array.value(i);
239 let singleton_variant_array: StructArray = VariantArray::from_iter([v]).into();
241
242 let arr = Arc::new(singleton_variant_array) as ArrayRef;
243
244 let res = variant_get(
245 &arr,
246 build_get_options(VariantPath::from(path.unwrap_or_default()), &type_field),
247 )?;
248
249 out.push(res);
250 }
251
252 let out_refs: Vec<&dyn Array> = out.iter().map(|a| a.as_ref()).collect();
253 ColumnarValue::Array(concat(&out_refs)?)
254 }
255 (ColumnarValue::Scalar(scalar_variant), ColumnarValue::Array(variant_paths)) => {
256 let ScalarValue::Struct(variant_array) = scalar_variant else {
257 return exec_err!("expected struct array");
258 };
259
260 let variant_array = Arc::clone(variant_array) as ArrayRef;
261 let variant_paths = try_parse_string_columnar(variant_paths)?;
262
263 let mut out = Vec::with_capacity(variant_paths.len());
264
265 for path in variant_paths {
266 let path = path.unwrap_or_default();
267 let res = variant_get(
268 &variant_array,
269 build_get_options(VariantPath::from(path), &type_field),
270 )?;
271
272 out.push(res);
273 }
274
275 let out_refs: Vec<&dyn Array> = out.iter().map(|a| a.as_ref()).collect();
276 ColumnarValue::Array(concat(&out_refs)?)
277 }
278 };
279
280 Ok(out)
281 }
282}
283
/// Scalar UDF exposed to DataFusion under the name `variant_pretty`.
///
/// Renders each variant value with its `Debug` formatting and returns the text as
/// `Utf8View`.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct VariantPretty {
    /// Signature reported to DataFusion through `ScalarUDFImpl::signature`.
    signature: Signature,
}
289
290impl Default for VariantPretty {
291 fn default() -> Self {
292 Self {
293 signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
294 }
295 }
296}
297
298impl ScalarUDFImpl for VariantPretty {
299 fn as_any(&self) -> &dyn std::any::Any {
300 self
301 }
302
303 fn name(&self) -> &str {
304 "variant_pretty"
305 }
306
307 fn signature(&self) -> &Signature {
308 &self.signature
309 }
310
311 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
312 Ok(DataType::Utf8View)
313 }
314
315 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
316 let field = args
317 .arg_fields
318 .first()
319 .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
320
321 try_field_as_variant_array(field.as_ref())?;
322
323 let arg = args
324 .args
325 .first()
326 .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
327
328 let out = match arg {
329 ColumnarValue::Scalar(scalar) => {
330 let ScalarValue::Struct(variant_array) = scalar else {
331 return exec_err!("Unsupported data type: {}", scalar.data_type());
332 };
333
334 let variant_array = VariantArray::try_new(variant_array.as_ref())?;
335 let v = variant_array.value(0);
336
337 ColumnarValue::Scalar(ScalarValue::Utf8View(Some(format!("{:?}", v))))
338 }
339 ColumnarValue::Array(arr) => match arr.data_type() {
340 DataType::Struct(_) => {
341 let variant_array = VariantArray::try_new(arr.as_ref())?;
342
343 let out = variant_array
344 .iter()
345 .map(|v| v.map(|v| format!("{:?}", v)))
346 .collect::<Vec<_>>();
347
348 let out: StringViewArray = out.into();
349
350 ColumnarValue::Array(Arc::new(out))
351 }
352 unsupported => return exec_err!("Invalid data type: {unsupported}"),
353 },
354 };
355
356 Ok(out)
357 }
358}
359
/// Scalar UDF exposed to DataFusion under the name `variant_to_json`.
///
/// Serializes each variant value to a JSON string returned as `Utf8View`. The `Default`
/// signature accepts one or two arguments; the implementation in this file reads only the
/// first.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct VariantToJsonUdf {
    /// Signature reported to DataFusion through `ScalarUDFImpl::signature`.
    signature: Signature,
}
369
370impl Default for VariantToJsonUdf {
371 fn default() -> Self {
372 Self {
373 signature: Signature::new(
374 TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]),
375 Volatility::Immutable,
376 ),
377 }
378 }
379}
380
381impl ScalarUDFImpl for VariantToJsonUdf {
382 fn as_any(&self) -> &dyn std::any::Any {
383 self
384 }
385
386 fn name(&self) -> &str {
387 "variant_to_json"
388 }
389
390 fn signature(&self) -> &Signature {
391 &self.signature
392 }
393
394 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
395 Ok(DataType::Utf8View)
396 }
397
398 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
399 let field = args
400 .arg_fields
401 .first()
402 .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
403
404 try_field_as_variant_array(field.as_ref())?;
405
406 let arg = args
407 .args
408 .first()
409 .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
410
411 let out = match arg {
412 ColumnarValue::Scalar(scalar) => {
413 let ScalarValue::Struct(variant_array) = scalar else {
414 return exec_err!("Unsupported data type: {}", scalar.data_type());
415 };
416
417 let variant_array = VariantArray::try_new(variant_array.as_ref())?;
418 let v = variant_array.value(0);
419
420 ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v.to_json_string()?)))
421 }
422 ColumnarValue::Array(arr) => match arr.data_type() {
423 DataType::Struct(_) => {
424 let variant_array = VariantArray::try_new(arr.as_ref())?;
425
426 let out: StringViewArray = variant_array
427 .iter()
428 .map(|v| v.map(|v| v.to_json_string()).transpose())
429 .collect::<Result<Vec<_>, _>>()?
430 .into();
431
432 ColumnarValue::Array(Arc::new(out))
433 }
434 unsupported => return exec_err!("Invalid data type: {unsupported}"),
435 },
436 };
437
438 Ok(out)
439 }
440}
441
#[cfg(test)]
mod tests {
    use arrow::array::{Array, BinaryViewArray};
    use arrow_schema::{Field, Fields};
    use datafusion::logical_expr::{ReturnFieldArgs, ScalarFunctionArgs};
    use parquet::variant::Variant;
    use parquet::variant::{VariantArrayBuilder, VariantType};
    use parquet_variant_json::JsonToVariant;

    use super::*;

    /// Builds the single-row variant scalar shared by the tests below.
    fn sample_variant_scalar() -> ScalarValue {
        let document = serde_json::json!({
            "name": "norm",
            "age": 50,
            "list": [false, true, ()]
        });

        let mut builder = VariantArrayBuilder::new(1);
        builder.append_json(document.to_string().as_str()).unwrap();

        ScalarValue::Struct(Arc::new(builder.build().into()))
    }

    /// Field describing the variant argument: a struct tagged with `VariantType`.
    fn variant_arg_field() -> Arc<Field> {
        Arc::new(
            Field::new("input", DataType::Struct(Fields::empty()), true)
                .with_extension_type(VariantType),
        )
    }

    /// Reads row 0 of the `BinaryViewArray` stored in column `column` of `struct_arr`.
    fn binary_view_value(struct_arr: &StructArray, column: usize) -> &[u8] {
        struct_arr
            .column(column)
            .as_any()
            .downcast_ref::<BinaryViewArray>()
            .unwrap()
            .value(0)
    }

    #[test]
    fn test_get_variant_scalar() {
        let udf = VariantGetUdf::default();

        let variant_field = variant_arg_field();
        let path_field = Arc::new(Field::new("path", DataType::Utf8, true));

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[variant_field.clone(), path_field],
                scalar_arguments: &[],
            })
            .unwrap();

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Scalar(sample_variant_scalar()),
                    ColumnarValue::Scalar(ScalarValue::Utf8(Some("name".to_string()))),
                ],
                return_field,
                arg_fields: vec![variant_field],
                number_rows: Default::default(),
                config_options: Default::default(),
            })
            .unwrap();

        // Without a type hint the result is itself a variant (metadata + value struct).
        let ColumnarValue::Scalar(ScalarValue::Struct(struct_arr)) = result else {
            panic!("expected ScalarValue struct");
        };

        assert_eq!(struct_arr.len(), 1);

        let metadata = binary_view_value(&struct_arr, 0);
        let value = binary_view_value(&struct_arr, 1);

        let v = Variant::try_new(metadata, value).unwrap();

        assert_eq!(v, Variant::from("norm"))
    }

    #[test]
    fn test_get_variant_scalar_typed() {
        let udf = VariantGetUdf::default();

        let variant_field = variant_arg_field();
        let path_field = Arc::new(Field::new("path", DataType::Utf8, true));
        let hint_field = Arc::new(Field::new("type_hint", DataType::Utf8, true));

        let path_scalar = ScalarValue::Utf8(Some("name".to_string()));
        let type_hint = ScalarValue::Utf8(Some("Utf8".to_string()));
        let scalar_arguments: [Option<&ScalarValue>; 3] =
            [None, Some(&path_scalar), Some(&type_hint)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[variant_field.clone(), path_field, hint_field],
                scalar_arguments: &scalar_arguments,
            })
            .unwrap();
        assert_eq!(return_field.data_type(), &DataType::Utf8);

        let result = udf
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![
                    ColumnarValue::Scalar(sample_variant_scalar()),
                    ColumnarValue::Scalar(path_scalar.clone()),
                    ColumnarValue::Scalar(type_hint.clone()),
                ],
                return_field,
                arg_fields: vec![variant_field],
                number_rows: Default::default(),
                config_options: Default::default(),
            })
            .unwrap();

        // With the `Utf8` hint the extracted value comes back as a plain string scalar.
        let ColumnarValue::Scalar(ScalarValue::Utf8(value)) = result else {
            panic!("expected Utf8 scalar");
        };

        assert_eq!(value.as_deref(), Some("norm"));
    }
}
582}