Skip to main content

liquid_cache_datafusion/reader/
variant_udf.rs

1//! Credit: this is copied from <https://github.com/datafusion-contrib/datafusion-variant/blob/main/src/variant_get.rs>
2//! Full credit to the original authors.
3//! We need to copy it here because we have different datafusion versions, and can't easily include that crate in our workspace.
4//! But eventually, we will use whatever the official datafusion has to offer.
5
6use std::sync::Arc;
7
8use arrow::{
9    array::{Array, ArrayRef, AsArray, StringViewArray, StructArray},
10    compute::concat,
11};
12use arrow_schema::{DataType, Field, FieldRef, Fields, extension::ExtensionType};
13use datafusion::{
14    common::{exec_datafusion_err, exec_err},
15    error::{DataFusionError, Result},
16    logical_expr::{
17        ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
18        TypeSignature, Volatility,
19    },
20    scalar::ScalarValue,
21};
22use parquet::variant::VariantPath;
23use parquet::variant::{GetOptions, VariantArray, VariantType, variant_get};
24use parquet_variant_json::VariantToJson;
25
26pub fn try_field_as_variant_array(field: &Field) -> Result<()> {
27    assert!(
28        matches!(field.extension_type(), VariantType),
29        "field does not have extension type VariantType"
30    );
31
32    let variant_type = VariantType;
33    variant_type.supports_data_type(field.data_type())?;
34
35    Ok(())
36}
37
38pub fn _try_field_as_binary(field: &Field) -> Result<()> {
39    match field.data_type() {
40        DataType::Binary | DataType::BinaryView | DataType::LargeBinary => {}
41        unsupported => return exec_err!("expected binary field, got {unsupported} field"),
42    }
43
44    Ok(())
45}
46
47pub fn try_parse_string_columnar(array: &Arc<dyn Array>) -> Result<Vec<Option<&str>>> {
48    if let Some(string_array) = array.as_string_opt::<i32>() {
49        return Ok(string_array.into_iter().collect::<Vec<_>>());
50    }
51
52    if let Some(string_view_array) = array.as_string_view_opt() {
53        return Ok(string_view_array.into_iter().collect::<Vec<_>>());
54    }
55
56    if let Some(large_string_array) = array.as_string_opt::<i64>() {
57        return Ok(large_string_array.into_iter().collect::<Vec<_>>());
58    }
59
60    Err(exec_datafusion_err!("expected string array"))
61}
62
63pub fn try_parse_string_scalar(scalar: &ScalarValue) -> Result<Option<&String>> {
64    let b = match scalar {
65        ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => s,
66        unsupported => {
67            return exec_err!(
68                "expected binary scalar value, got data type: {}",
69                unsupported.data_type()
70            );
71        }
72    };
73
74    Ok(b.as_ref())
75}
76
77fn parse_type_hint(spec: &str) -> Result<DataType> {
78    if let Ok(data_type) = spec.parse::<DataType>() {
79        Ok(data_type)
80    } else {
81        exec_err!("invalid type hint: {spec}")
82    }
83}
84
85fn type_hint_from_scalar(field_name: &str, scalar: &ScalarValue) -> Result<FieldRef> {
86    let type_name = match scalar {
87        ScalarValue::Utf8(Some(value))
88        | ScalarValue::Utf8View(Some(value))
89        | ScalarValue::LargeUtf8(Some(value)) => value.as_str(),
90        other => {
91            return exec_err!(
92                "type hint must be a non-null UTF8 literal, got {}",
93                other.data_type()
94            );
95        }
96    };
97
98    let data_type = parse_type_hint(type_name)?;
99    Ok(Arc::new(Field::new(field_name, data_type, true)))
100}
101
102fn type_hint_from_value(field_name: &str, arg: &ColumnarValue) -> Result<FieldRef> {
103    match arg {
104        ColumnarValue::Scalar(value) => type_hint_from_scalar(field_name, value),
105        ColumnarValue::Array(_) => {
106            exec_err!("type hint argument must be a scalar UTF8 literal")
107        }
108    }
109}
110
111fn build_get_options<'a>(path: VariantPath<'a>, as_type: &Option<FieldRef>) -> GetOptions<'a> {
112    match as_type {
113        Some(field) => GetOptions::new_with_path(path).with_as_type(Some(field.clone())),
114        None => GetOptions::new_with_path(path),
115    }
116}
117
118/// UDF for getting a variant from a variant array or scalar.
119#[derive(Debug, Hash, PartialEq, Eq)]
120pub struct VariantGetUdf {
121    signature: Signature,
122}
123
124impl Default for VariantGetUdf {
125    fn default() -> Self {
126        Self {
127            signature: Signature::new(
128                TypeSignature::OneOf(vec![TypeSignature::Any(2), TypeSignature::Any(3)]),
129                Volatility::Immutable,
130            ),
131        }
132    }
133}
134
135impl ScalarUDFImpl for VariantGetUdf {
136    fn as_any(&self) -> &dyn std::any::Any {
137        self
138    }
139
140    fn name(&self) -> &str {
141        "variant_get"
142    }
143
144    fn signature(&self) -> &Signature {
145        &self.signature
146    }
147
148    fn return_type(&self, _arg_types: &[arrow_schema::DataType]) -> Result<arrow_schema::DataType> {
149        Err(DataFusionError::Internal(
150            "implemented return_field_from_args instead".into(),
151        ))
152    }
153
154    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<Arc<Field>> {
155        if let Some(maybe_scalar) = args.scalar_arguments.get(2) {
156            let scalar = maybe_scalar.ok_or_else(|| {
157                exec_datafusion_err!("type hint argument to variant_get must be a literal")
158            })?;
159            return type_hint_from_scalar(self.name(), scalar);
160        }
161
162        let data_type = DataType::Struct(Fields::from(vec![
163            Field::new("metadata", DataType::BinaryView, false),
164            Field::new("value", DataType::BinaryView, true),
165        ]));
166
167        Ok(Arc::new(
168            Field::new(self.name(), data_type, true).with_extension_type(VariantType),
169        ))
170    }
171
172    fn invoke_with_args(
173        &self,
174        args: datafusion::logical_expr::ScalarFunctionArgs,
175    ) -> Result<ColumnarValue> {
176        let (variant_arg, variant_path, type_arg) = match args.args.as_slice() {
177            [variant_arg, variant_path] => (variant_arg, variant_path, None),
178            [variant_arg, variant_path, type_arg] => (variant_arg, variant_path, Some(type_arg)),
179            _ => return exec_err!("expected 2 or 3 arguments"),
180        };
181
182        let variant_field = args
183            .arg_fields
184            .first()
185            .ok_or_else(|| exec_datafusion_err!("expected argument field"))?;
186
187        try_field_as_variant_array(variant_field.as_ref())?;
188
189        let type_field = type_arg
190            .map(|arg| type_hint_from_value(self.name(), arg))
191            .transpose()?;
192
193        let out = match (variant_arg, variant_path) {
194            (ColumnarValue::Array(variant_array), ColumnarValue::Scalar(variant_path)) => {
195                let variant_path = try_parse_string_scalar(variant_path)?
196                    .map(|s| s.as_str())
197                    .unwrap_or_default();
198
199                let res = variant_get(
200                    variant_array,
201                    build_get_options(VariantPath::from(variant_path), &type_field),
202                )?;
203
204                ColumnarValue::Array(res)
205            }
206            (ColumnarValue::Scalar(scalar_variant), ColumnarValue::Scalar(variant_path)) => {
207                let ScalarValue::Struct(variant_array) = scalar_variant else {
208                    return exec_err!("expected struct array");
209                };
210
211                let variant_array = Arc::clone(variant_array) as ArrayRef;
212
213                let variant_path = try_parse_string_scalar(variant_path)?
214                    .map(|s| s.as_str())
215                    .unwrap_or_default();
216
217                let res = variant_get(
218                    &variant_array,
219                    build_get_options(VariantPath::from(variant_path), &type_field),
220                )?;
221
222                let scalar = ScalarValue::try_from_array(res.as_ref(), 0)?;
223                ColumnarValue::Scalar(scalar)
224            }
225            (ColumnarValue::Array(variant_array), ColumnarValue::Array(variant_paths)) => {
226                if variant_array.len() != variant_paths.len() {
227                    return exec_err!(
228                        "expected variant_array and variant paths to be of same length"
229                    );
230                }
231
232                let variant_paths = try_parse_string_columnar(variant_paths)?;
233                let variant_array = VariantArray::try_new(variant_array.as_ref())?;
234
235                let mut out = Vec::with_capacity(variant_array.len());
236
237                for (i, path) in variant_paths.iter().enumerate() {
238                    let v = variant_array.value(i);
239                    // todo: is there a better way to go from Variant -> VariantArray?
240                    let singleton_variant_array: StructArray = VariantArray::from_iter([v]).into();
241
242                    let arr = Arc::new(singleton_variant_array) as ArrayRef;
243
244                    let res = variant_get(
245                        &arr,
246                        build_get_options(VariantPath::from(path.unwrap_or_default()), &type_field),
247                    )?;
248
249                    out.push(res);
250                }
251
252                let out_refs: Vec<&dyn Array> = out.iter().map(|a| a.as_ref()).collect();
253                ColumnarValue::Array(concat(&out_refs)?)
254            }
255            (ColumnarValue::Scalar(scalar_variant), ColumnarValue::Array(variant_paths)) => {
256                let ScalarValue::Struct(variant_array) = scalar_variant else {
257                    return exec_err!("expected struct array");
258                };
259
260                let variant_array = Arc::clone(variant_array) as ArrayRef;
261                let variant_paths = try_parse_string_columnar(variant_paths)?;
262
263                let mut out = Vec::with_capacity(variant_paths.len());
264
265                for path in variant_paths {
266                    let path = path.unwrap_or_default();
267                    let res = variant_get(
268                        &variant_array,
269                        build_get_options(VariantPath::from(path), &type_field),
270                    )?;
271
272                    out.push(res);
273                }
274
275                let out_refs: Vec<&dyn Array> = out.iter().map(|a| a.as_ref()).collect();
276                ColumnarValue::Array(concat(&out_refs)?)
277            }
278        };
279
280        Ok(out)
281    }
282}
283
284/// Returns a pretty-printed JSON string from a VariantArray
285#[derive(Debug, Hash, PartialEq, Eq)]
286pub struct VariantPretty {
287    signature: Signature,
288}
289
290impl Default for VariantPretty {
291    fn default() -> Self {
292        Self {
293            signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable),
294        }
295    }
296}
297
298impl ScalarUDFImpl for VariantPretty {
299    fn as_any(&self) -> &dyn std::any::Any {
300        self
301    }
302
303    fn name(&self) -> &str {
304        "variant_pretty"
305    }
306
307    fn signature(&self) -> &Signature {
308        &self.signature
309    }
310
311    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
312        Ok(DataType::Utf8View)
313    }
314
315    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
316        let field = args
317            .arg_fields
318            .first()
319            .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
320
321        try_field_as_variant_array(field.as_ref())?;
322
323        let arg = args
324            .args
325            .first()
326            .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
327
328        let out = match arg {
329            ColumnarValue::Scalar(scalar) => {
330                let ScalarValue::Struct(variant_array) = scalar else {
331                    return exec_err!("Unsupported data type: {}", scalar.data_type());
332                };
333
334                let variant_array = VariantArray::try_new(variant_array.as_ref())?;
335                let v = variant_array.value(0);
336
337                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(format!("{:?}", v))))
338            }
339            ColumnarValue::Array(arr) => match arr.data_type() {
340                DataType::Struct(_) => {
341                    let variant_array = VariantArray::try_new(arr.as_ref())?;
342
343                    let out = variant_array
344                        .iter()
345                        .map(|v| v.map(|v| format!("{:?}", v)))
346                        .collect::<Vec<_>>();
347
348                    let out: StringViewArray = out.into();
349
350                    ColumnarValue::Array(Arc::new(out))
351                }
352                unsupported => return exec_err!("Invalid data type: {unsupported}"),
353            },
354        };
355
356        Ok(out)
357    }
358}
359
360/// Returns a JSON string from a VariantArray
361///
362/// ## Arguments
363/// - expr: a DataType::Struct expression that represents a VariantArray
364/// - options: an optional MAP (note, it seems arrow-rs' parquet-variant is pretty restrictive about the options)
365#[derive(Debug, Hash, PartialEq, Eq)]
366pub struct VariantToJsonUdf {
367    signature: Signature,
368}
369
370impl Default for VariantToJsonUdf {
371    fn default() -> Self {
372        Self {
373            signature: Signature::new(
374                TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]),
375                Volatility::Immutable,
376            ),
377        }
378    }
379}
380
381impl ScalarUDFImpl for VariantToJsonUdf {
382    fn as_any(&self) -> &dyn std::any::Any {
383        self
384    }
385
386    fn name(&self) -> &str {
387        "variant_to_json"
388    }
389
390    fn signature(&self) -> &Signature {
391        &self.signature
392    }
393
394    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
395        Ok(DataType::Utf8View)
396    }
397
398    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
399        let field = args
400            .arg_fields
401            .first()
402            .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
403
404        try_field_as_variant_array(field.as_ref())?;
405
406        let arg = args
407            .args
408            .first()
409            .ok_or_else(|| exec_datafusion_err!("empty argument, expected 1 argument"))?;
410
411        let out = match arg {
412            ColumnarValue::Scalar(scalar) => {
413                let ScalarValue::Struct(variant_array) = scalar else {
414                    return exec_err!("Unsupported data type: {}", scalar.data_type());
415                };
416
417                let variant_array = VariantArray::try_new(variant_array.as_ref())?;
418                let v = variant_array.value(0);
419
420                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v.to_json_string()?)))
421            }
422            ColumnarValue::Array(arr) => match arr.data_type() {
423                DataType::Struct(_) => {
424                    let variant_array = VariantArray::try_new(arr.as_ref())?;
425
426                    let out: StringViewArray = variant_array
427                        .iter()
428                        .map(|v| v.map(|v| v.to_json_string()).transpose())
429                        .collect::<Result<Vec<_>, _>>()?
430                        .into();
431
432                    ColumnarValue::Array(Arc::new(out))
433                }
434                unsupported => return exec_err!("Invalid data type: {unsupported}"),
435            },
436        };
437
438        Ok(out)
439    }
440}
441
#[cfg(test)]
mod tests {
    use arrow::array::{Array, BinaryViewArray};
    use arrow_schema::{Field, Fields};
    use datafusion::logical_expr::{ReturnFieldArgs, ScalarFunctionArgs};
    use parquet::variant::Variant;
    use parquet::variant::{VariantArrayBuilder, VariantType};
    use parquet_variant_json::JsonToVariant;

    use super::*;

    /// Builds the single-row variant scalar shared by the tests: an object with a
    /// string, a number and a list containing booleans and a null.
    fn sample_variant_scalar() -> ScalarValue {
        let json = serde_json::json!({
            "name": "norm",
            "age": 50,
            "list": [false, true, ()]
        });

        let json_str = json.to_string();
        let mut builder = VariantArrayBuilder::new(1);
        builder.append_json(json_str.as_str()).unwrap();

        ScalarValue::Struct(Arc::new(builder.build().into()))
    }

    /// Field describing the variant argument: a struct tagged with `VariantType`.
    fn variant_arg_field() -> FieldRef {
        Arc::new(
            Field::new("input", DataType::Struct(Fields::empty()), true)
                .with_extension_type(VariantType),
        )
    }

    /// Without a type hint, `variant_get` on a scalar returns a variant scalar.
    #[test]
    fn test_get_variant_scalar() {
        let variant_input = sample_variant_scalar();
        let path = "name";

        let udf = VariantGetUdf::default();

        // One field per argument, so `arg_fields` lines up with `args` below.
        let arg_fields = vec![
            variant_arg_field(),
            Arc::new(Field::new("path", DataType::Utf8, true)),
        ];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[],
            })
            .unwrap();

        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Scalar(variant_input),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(path.to_string()))),
            ],
            return_field,
            arg_fields,
            number_rows: Default::default(),
            config_options: Default::default(),
        };

        let result = udf.invoke_with_args(args).unwrap();

        let ColumnarValue::Scalar(ScalarValue::Struct(struct_arr)) = result else {
            panic!("expected ScalarValue struct");
        };

        assert_eq!(struct_arr.len(), 1);

        let metadata_arr = struct_arr
            .column(0)
            .as_any()
            .downcast_ref::<BinaryViewArray>()
            .unwrap();
        let value_arr = struct_arr
            .column(1)
            .as_any()
            .downcast_ref::<BinaryViewArray>()
            .unwrap();

        let metadata = metadata_arr.value(0);
        let value = value_arr.value(0);

        let v = Variant::try_new(metadata, value).unwrap();

        assert_eq!(v, Variant::from("norm"));
    }

    /// With a `Utf8` type hint, `variant_get` on a scalar returns a plain string scalar.
    #[test]
    fn test_get_variant_scalar_typed() {
        let variant_input = sample_variant_scalar();
        let path = "name";

        let udf = VariantGetUdf::default();

        // One field per argument, so `arg_fields` lines up with `args` below.
        let arg_fields = vec![
            variant_arg_field(),
            Arc::new(Field::new("path", DataType::Utf8, true)),
            Arc::new(Field::new("type_hint", DataType::Utf8, true)),
        ];

        let path_scalar = ScalarValue::Utf8(Some(path.to_string()));
        let type_hint = ScalarValue::Utf8(Some("Utf8".to_string()));
        let scalar_arguments: [Option<&ScalarValue>; 3] =
            [None, Some(&path_scalar), Some(&type_hint)];

        let return_field = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &scalar_arguments,
            })
            .unwrap();
        assert_eq!(return_field.data_type(), &DataType::Utf8);

        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Scalar(variant_input),
                ColumnarValue::Scalar(path_scalar.clone()),
                ColumnarValue::Scalar(type_hint.clone()),
            ],
            return_field,
            arg_fields,
            number_rows: Default::default(),
            config_options: Default::default(),
        };

        let result = udf.invoke_with_args(args).unwrap();

        let ColumnarValue::Scalar(ScalarValue::Utf8(value)) = result else {
            panic!("expected Utf8 scalar");
        };

        assert_eq!(value.as_deref(), Some("norm"));
    }
}