Skip to main content

datafusion_spark/function/json/
json_tuple.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{Array, ArrayRef, NullBufferBuilder, StringBuilder, StructArray};
22use arrow::datatypes::{DataType, Field, FieldRef, Fields};
23use datafusion_common::cast::as_string_array;
24use datafusion_common::{Result, exec_err, internal_err};
25use datafusion_expr::{
26    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27    Volatility,
28};
29
30/// Spark-compatible `json_tuple` expression
31///
32/// <https://spark.apache.org/docs/latest/api/sql/index.html#json_tuple>
33///
34/// Extracts top-level fields from a JSON string and returns them as a struct.
35///
36/// `json_tuple(json_string, field1, field2, ...) -> Struct<c0: Utf8, c1: Utf8, ...>`
37///
38/// Note: In Spark, `json_tuple` is a Generator that produces multiple columns directly.
39/// In DataFusion, a ScalarUDF can only return one value per row, so the result is wrapped
40/// in a Struct. The caller (e.g. Comet) is expected to destructure the struct fields.
41///
42/// - Returns NULL for each field that is missing from the JSON object
43/// - Returns NULL for all fields if the input is NULL or not valid JSON
44/// - Non-string JSON values are converted to their JSON string representation
45/// - JSON `null` values are returned as NULL (not the string "null")
46#[derive(Debug, PartialEq, Eq, Hash)]
47pub struct JsonTuple {
48    signature: Signature,
49}
50
51impl Default for JsonTuple {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl JsonTuple {
58    pub fn new() -> Self {
59        Self {
60            signature: Signature::variadic(vec![DataType::Utf8], Volatility::Immutable),
61        }
62    }
63}
64
65impl ScalarUDFImpl for JsonTuple {
66    fn as_any(&self) -> &dyn Any {
67        self
68    }
69
70    fn name(&self) -> &str {
71        "json_tuple"
72    }
73
74    fn signature(&self) -> &Signature {
75        &self.signature
76    }
77
78    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
79        internal_err!("return_field_from_args should be used instead")
80    }
81
82    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
83        if args.arg_fields.len() < 2 {
84            return exec_err!(
85                "json_tuple requires at least 2 arguments (json_string, field1), got {}",
86                args.arg_fields.len()
87            );
88        }
89
90        let num_fields = args.arg_fields.len() - 1;
91        let fields: Fields = (0..num_fields)
92            .map(|i| Field::new(format!("c{i}"), DataType::Utf8, true))
93            .collect::<Vec<_>>()
94            .into();
95
96        Ok(Arc::new(Field::new(
97            self.name(),
98            DataType::Struct(fields),
99            true,
100        )))
101    }
102
103    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
104        let ScalarFunctionArgs {
105            args: arg_values,
106            return_field,
107            ..
108        } = args;
109        let arrays = ColumnarValue::values_to_arrays(&arg_values)?;
110        let result = json_tuple_inner(&arrays, return_field.data_type())?;
111
112        Ok(ColumnarValue::Array(result))
113    }
114}
115
116fn json_tuple_inner(args: &[ArrayRef], return_type: &DataType) -> Result<ArrayRef> {
117    let num_rows = args[0].len();
118    let num_fields = args.len() - 1;
119
120    let json_array = as_string_array(&args[0])?;
121
122    let field_arrays = args[1..]
123        .iter()
124        .map(|arg| as_string_array(arg))
125        .collect::<Result<Vec<_>>>()?;
126
127    let mut builders: Vec<StringBuilder> =
128        (0..num_fields).map(|_| StringBuilder::new()).collect();
129
130    let mut null_buffer = NullBufferBuilder::new(num_rows);
131
132    for row_idx in 0..num_rows {
133        if json_array.is_null(row_idx) {
134            for builder in &mut builders {
135                builder.append_null();
136            }
137            null_buffer.append_null();
138            continue;
139        }
140
141        let json_str = json_array.value(row_idx);
142        match serde_json::from_str::<serde_json::Value>(json_str) {
143            Ok(serde_json::Value::Object(map)) => {
144                null_buffer.append_non_null();
145                for (field_idx, builder) in builders.iter_mut().enumerate() {
146                    if field_arrays[field_idx].is_null(row_idx) {
147                        builder.append_null();
148                        continue;
149                    }
150                    let field_name = field_arrays[field_idx].value(row_idx);
151                    match map.get(field_name) {
152                        Some(serde_json::Value::Null) => {
153                            builder.append_null();
154                        }
155                        Some(serde_json::Value::String(s)) => {
156                            builder.append_value(s);
157                        }
158                        Some(other) => {
159                            builder.append_value(other.to_string());
160                        }
161                        None => {
162                            builder.append_null();
163                        }
164                    }
165                }
166            }
167            _ => {
168                for builder in &mut builders {
169                    builder.append_null();
170                }
171                null_buffer.append_null();
172            }
173        }
174    }
175
176    let struct_fields = match return_type {
177        DataType::Struct(fields) => fields.clone(),
178        _ => {
179            return internal_err!(
180                "json_tuple requires a Struct return type, got {:?}",
181                return_type
182            );
183        }
184    };
185
186    let arrays: Vec<ArrayRef> = builders
187        .into_iter()
188        .map(|mut builder| Arc::new(builder.finish()) as ArrayRef)
189        .collect();
190
191    let struct_array = StructArray::try_new(struct_fields, arrays, null_buffer.finish())?;
192
193    Ok(Arc::new(struct_array))
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringArray;
    use datafusion_expr::ReturnFieldArgs;

    /// Builds the `Struct<c0: Utf8, c1: Utf8, ...>` type with `n` nullable members.
    fn utf8_struct(n: usize) -> DataType {
        let members: Vec<Field> = (0..n)
            .map(|i| Field::new(format!("c{i}"), DataType::Utf8, true))
            .collect();
        DataType::Struct(Fields::from(members))
    }

    #[test]
    fn test_return_field_shape() {
        let func = JsonTuple::new();
        let fields = vec![
            Arc::new(Field::new("json", DataType::Utf8, false)),
            Arc::new(Field::new("f1", DataType::Utf8, false)),
            Arc::new(Field::new("f2", DataType::Utf8, false)),
        ];
        let result = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &fields,
                scalar_arguments: &[None, None, None],
            })
            .unwrap();

        match result.data_type() {
            DataType::Struct(inner) => {
                assert_eq!(inner.len(), 2);
                assert_eq!(inner[0].name(), "c0");
                assert_eq!(inner[1].name(), "c1");
                assert_eq!(inner[0].data_type(), &DataType::Utf8);
                assert!(inner[0].is_nullable());
            }
            other => panic!("Expected Struct, got {other:?}"),
        }
    }

    #[test]
    fn test_too_few_args() {
        let func = JsonTuple::new();
        let fields = vec![Arc::new(Field::new("json", DataType::Utf8, false))];
        let result = func.return_field_from_args(ReturnFieldArgs {
            arg_fields: &fields,
            scalar_arguments: &[None],
        });
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("at least 2 arguments")
        );
    }

    /// Exercises the row-level evaluation: value extraction, stringification of
    /// non-string values, and the NULL rules for missing keys and bad input.
    #[test]
    fn test_inner_extracts_values_and_nulls() {
        let object = r#"{"a":"x","b":2,"c":null}"#;
        let json: ArrayRef = Arc::new(StringArray::from(vec![
            Some(object),
            Some(object),
            Some("not json"),
            Some("[1,2]"),
            None,
        ]));
        let key0: ArrayRef =
            Arc::new(StringArray::from(vec!["a", "c", "a", "a", "a"]));
        let key1: ArrayRef =
            Arc::new(StringArray::from(vec!["b", "missing", "b", "b", "b"]));

        let out = json_tuple_inner(&[json, key0, key1], &utf8_struct(2)).unwrap();
        let out = out.as_any().downcast_ref::<StructArray>().unwrap();
        assert_eq!(out.len(), 5);
        assert_eq!(out.num_columns(), 2);

        let c0 = as_string_array(out.column(0)).unwrap();
        let c1 = as_string_array(out.column(1)).unwrap();

        // String values are copied; numbers keep their JSON text.
        assert!(out.is_valid(0));
        assert_eq!(c0.value(0), "x");
        assert_eq!(c1.value(0), "2");

        // JSON null and a missing key are NULL members of a valid struct.
        assert!(out.is_valid(1));
        assert!(c0.is_null(1));
        assert!(c1.is_null(1));

        // Invalid JSON, a non-object document, and NULL input null the row.
        for row in 2..5 {
            assert!(out.is_null(row));
            assert!(c0.is_null(row));
            assert!(c1.is_null(row));
        }
    }
}