Skip to main content

datafusion_spark/function/json/
json_tuple.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use arrow::array::{Array, ArrayRef, NullBufferBuilder, StringBuilder, StructArray};
21use arrow::datatypes::{DataType, Field, FieldRef, Fields};
22use datafusion_common::cast::as_string_array;
23use datafusion_common::{Result, exec_err, internal_err};
24use datafusion_expr::{
25    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
26    Volatility,
27};
28
29/// Spark-compatible `json_tuple` expression
30///
31/// <https://spark.apache.org/docs/latest/api/sql/index.html#json_tuple>
32///
33/// Extracts top-level fields from a JSON string and returns them as a struct.
34///
35/// `json_tuple(json_string, field1, field2, ...) -> Struct<c0: Utf8, c1: Utf8, ...>`
36///
37/// Note: In Spark, `json_tuple` is a Generator that produces multiple columns directly.
38/// In DataFusion, a ScalarUDF can only return one value per row, so the result is wrapped
39/// in a Struct. The caller (e.g. Comet) is expected to destructure the struct fields.
40///
41/// - Returns NULL for each field that is missing from the JSON object
42/// - Returns NULL for all fields if the input is NULL or not valid JSON
43/// - Non-string JSON values are converted to their JSON string representation
44/// - JSON `null` values are returned as NULL (not the string "null")
45#[derive(Debug, PartialEq, Eq, Hash)]
46pub struct JsonTuple {
47    signature: Signature,
48}
49
50impl Default for JsonTuple {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl JsonTuple {
57    pub fn new() -> Self {
58        Self {
59            signature: Signature::variadic(vec![DataType::Utf8], Volatility::Immutable),
60        }
61    }
62}
63
64impl ScalarUDFImpl for JsonTuple {
65    fn name(&self) -> &str {
66        "json_tuple"
67    }
68
69    fn signature(&self) -> &Signature {
70        &self.signature
71    }
72
73    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
74        internal_err!("return_field_from_args should be used instead")
75    }
76
77    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
78        if args.arg_fields.len() < 2 {
79            return exec_err!(
80                "json_tuple requires at least 2 arguments (json_string, field1), got {}",
81                args.arg_fields.len()
82            );
83        }
84
85        let num_fields = args.arg_fields.len() - 1;
86        let fields: Fields = (0..num_fields)
87            .map(|i| Field::new(format!("c{i}"), DataType::Utf8, true))
88            .collect::<Vec<_>>()
89            .into();
90
91        Ok(Arc::new(Field::new(
92            self.name(),
93            DataType::Struct(fields),
94            true,
95        )))
96    }
97
98    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99        let ScalarFunctionArgs {
100            args: arg_values,
101            return_field,
102            ..
103        } = args;
104        let arrays = ColumnarValue::values_to_arrays(&arg_values)?;
105        let result = json_tuple_inner(&arrays, return_field.data_type())?;
106
107        Ok(ColumnarValue::Array(result))
108    }
109}
110
111fn json_tuple_inner(args: &[ArrayRef], return_type: &DataType) -> Result<ArrayRef> {
112    let num_rows = args[0].len();
113    let num_fields = args.len() - 1;
114
115    let json_array = as_string_array(&args[0])?;
116
117    let field_arrays = args[1..]
118        .iter()
119        .map(|arg| as_string_array(arg))
120        .collect::<Result<Vec<_>>>()?;
121
122    let mut builders: Vec<StringBuilder> =
123        (0..num_fields).map(|_| StringBuilder::new()).collect();
124
125    let mut null_buffer = NullBufferBuilder::new(num_rows);
126
127    for row_idx in 0..num_rows {
128        if json_array.is_null(row_idx) {
129            for builder in &mut builders {
130                builder.append_null();
131            }
132            null_buffer.append_null();
133            continue;
134        }
135
136        let json_str = json_array.value(row_idx);
137        match serde_json::from_str::<serde_json::Value>(json_str) {
138            Ok(serde_json::Value::Object(map)) => {
139                null_buffer.append_non_null();
140                for (field_idx, builder) in builders.iter_mut().enumerate() {
141                    if field_arrays[field_idx].is_null(row_idx) {
142                        builder.append_null();
143                        continue;
144                    }
145                    let field_name = field_arrays[field_idx].value(row_idx);
146                    match map.get(field_name) {
147                        Some(serde_json::Value::Null) => {
148                            builder.append_null();
149                        }
150                        Some(serde_json::Value::String(s)) => {
151                            builder.append_value(s);
152                        }
153                        Some(other) => {
154                            builder.append_value(other.to_string());
155                        }
156                        None => {
157                            builder.append_null();
158                        }
159                    }
160                }
161            }
162            _ => {
163                for builder in &mut builders {
164                    builder.append_null();
165                }
166                null_buffer.append_null();
167            }
168        }
169    }
170
171    let struct_fields = match return_type {
172        DataType::Struct(fields) => fields.clone(),
173        _ => {
174            return internal_err!(
175                "json_tuple requires a Struct return type, got {:?}",
176                return_type
177            );
178        }
179    };
180
181    let arrays: Vec<ArrayRef> = builders
182        .into_iter()
183        .map(|mut builder| Arc::new(builder.finish()) as ArrayRef)
184        .collect();
185
186    let struct_array = StructArray::try_new(struct_fields, arrays, null_buffer.finish())?;
187
188    Ok(Arc::new(struct_array))
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-nullable `Utf8` argument field with the given name.
    fn utf8_arg(name: &str) -> FieldRef {
        Arc::new(Field::new(name, DataType::Utf8, false))
    }

    /// Three arguments (JSON + two field names) must yield a struct with two
    /// nullable `Utf8` members named `c0` and `c1`.
    #[test]
    fn test_return_field_shape() {
        let arg_fields = vec![utf8_arg("json"), utf8_arg("f1"), utf8_arg("f2")];
        let udf = JsonTuple::new();
        let returned = udf
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &arg_fields,
                scalar_arguments: &[None, None, None],
            })
            .unwrap();

        match returned.data_type() {
            DataType::Struct(inner) => {
                assert_eq!(inner.len(), 2);
                let first = &inner[0];
                let second = &inner[1];
                assert_eq!(first.name(), "c0");
                assert_eq!(second.name(), "c1");
                assert_eq!(first.data_type(), &DataType::Utf8);
                assert!(first.is_nullable());
            }
            other => panic!("Expected Struct, got {other:?}"),
        }
    }

    /// A lone JSON argument without any field name is rejected.
    #[test]
    fn test_too_few_args() {
        let arg_fields = vec![utf8_arg("json")];
        let udf = JsonTuple::new();
        let outcome = udf.return_field_from_args(ReturnFieldArgs {
            arg_fields: &arg_fields,
            scalar_arguments: &[None],
        });

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("at least 2 arguments"));
    }
}