Skip to main content

datafusion_functions_json/
common_union.rs

1use std::collections::HashMap;
2use std::sync::{Arc, LazyLock, OnceLock};
3
4use datafusion::arrow::array::{
5    Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray,
6};
7use datafusion::arrow::buffer::{Buffer, ScalarBuffer};
8use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
9use datafusion::arrow::error::ArrowError;
10use datafusion::common::ScalarValue;
11
/// Build the field metadata that flags a `Utf8` field as holding raw JSON.
///
/// Put this on any Arrow `Field` whose values are JSON-encoded strings, so
/// that downstream consumers can tell them apart from opaque text.
///
/// The map carries Arrow's canonical JSON extension type keys
/// (`ARROW:extension:name` = `arrow.json`, `ARROW:extension:metadata` = `{}`),
/// see <https://arrow.apache.org/docs/format/CanonicalExtensions.html#json>.
///
/// It additionally carries a legacy `is_json` = `true` key. That key predates
/// this crate's adoption of the canonical extension and is non-standard — no
/// other Arrow tool recognizes it. It only exists for back-compat with
/// existing downstream consumers of this crate and will be removed in a
/// future release; new consumers should key off `ARROW:extension:name`.
#[must_use]
pub fn json_field_metadata() -> HashMap<String, String> {
    // The last entry is the legacy, non-standard key. Remove it in a future
    // release — see the doc comment above.
    const ENTRIES: [(&str, &str); 3] = [
        ("ARROW:extension:name", "arrow.json"),
        ("ARROW:extension:metadata", "{}"),
        ("is_json", "true"),
    ];

    ENTRIES
        .iter()
        .map(|&(key, value)| (key.to_owned(), value.to_owned()))
        .collect()
}
35
36pub fn is_json_union(data_type: &DataType) -> bool {
37    match data_type {
38        DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(),
39        _ => false,
40    }
41}
42
43/// Extract nested JSON from a `JsonUnion` `UnionArray`
44///
45/// # Arguments
46/// * `array` - The `UnionArray` to extract the nested JSON from
47/// * `object_lookup` - If `true`, extract from the "object" member of the union,
48///   otherwise extract from the "array" member
49pub(crate) fn nested_json_array(array: &ArrayRef, object_lookup: bool) -> Option<&StringArray> {
50    nested_json_array_ref(array, object_lookup).map(AsArray::as_string)
51}
52
53pub(crate) fn nested_json_array_ref(array: &ArrayRef, object_lookup: bool) -> Option<&ArrayRef> {
54    let union_array: &UnionArray = array.as_any().downcast_ref::<UnionArray>()?;
55    let type_id = if object_lookup { TYPE_ID_OBJECT } else { TYPE_ID_ARRAY };
56    Some(union_array.child(type_id))
57}
58
59/// Extract a JSON string from a `JsonUnion` scalar
60pub(crate) fn json_from_union_scalar<'a>(
61    type_id_value: Option<&'a (i8, Box<ScalarValue>)>,
62    fields: &UnionFields,
63) -> Option<&'a str> {
64    if let Some((type_id, value)) = type_id_value {
65        // we only want to take the ScalarValue string if the type_id indicates the value represents nested JSON
66        if fields == &union_fields() && (*type_id == TYPE_ID_ARRAY || *type_id == TYPE_ID_OBJECT) {
67            if let ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) = value.as_ref() {
68                return s.as_deref();
69            }
70        }
71    }
72    None
73}
74
75pub static JSON_UNION_DATA_TYPE: LazyLock<DataType> = LazyLock::new(JsonUnion::data_type);
76
77#[derive(Debug)]
78pub(crate) struct JsonUnion {
79    bools: Vec<Option<bool>>,
80    ints: Vec<Option<i64>>,
81    floats: Vec<Option<f64>>,
82    strings: Vec<Option<String>>,
83    arrays: Vec<Option<String>>,
84    objects: Vec<Option<String>>,
85    type_ids: Vec<i8>,
86    index: usize,
87    length: usize,
88}
89
90impl JsonUnion {
91    pub fn new(length: usize) -> Self {
92        Self {
93            bools: vec![None; length],
94            ints: vec![None; length],
95            floats: vec![None; length],
96            strings: vec![None; length],
97            arrays: vec![None; length],
98            objects: vec![None; length],
99            type_ids: vec![TYPE_ID_NULL; length],
100            index: 0,
101            length,
102        }
103    }
104
105    pub fn data_type() -> DataType {
106        DataType::Union(union_fields(), UnionMode::Sparse)
107    }
108
109    pub fn push(&mut self, field: JsonUnionField) {
110        self.type_ids[self.index] = field.type_id();
111        match field {
112            JsonUnionField::JsonNull => (),
113            JsonUnionField::Bool(value) => self.bools[self.index] = Some(value),
114            JsonUnionField::Int(value) => self.ints[self.index] = Some(value),
115            JsonUnionField::Float(value) => self.floats[self.index] = Some(value),
116            JsonUnionField::Str(value) => self.strings[self.index] = Some(value),
117            JsonUnionField::Array(value) => self.arrays[self.index] = Some(value),
118            JsonUnionField::Object(value) => self.objects[self.index] = Some(value),
119        }
120        self.index += 1;
121        debug_assert!(self.index <= self.length);
122    }
123
124    pub fn push_none(&mut self) {
125        self.index += 1;
126        debug_assert!(self.index <= self.length);
127    }
128}
129
130/// So we can do `collect::<JsonUnion>()`
131impl FromIterator<Option<JsonUnionField>> for JsonUnion {
132    fn from_iter<I: IntoIterator<Item = Option<JsonUnionField>>>(iter: I) -> Self {
133        let inner = iter.into_iter();
134        let (lower, upper) = inner.size_hint();
135        let mut union = Self::new(upper.unwrap_or(lower));
136
137        for opt_field in inner {
138            if let Some(union_field) = opt_field {
139                union.push(union_field);
140            } else {
141                union.push_none();
142            }
143        }
144        union
145    }
146}
147
148impl TryFrom<JsonUnion> for UnionArray {
149    type Error = ArrowError;
150
151    fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
152        let children: Vec<Arc<dyn Array>> = vec![
153            Arc::new(NullArray::new(value.length)),
154            Arc::new(BooleanArray::from(value.bools)),
155            Arc::new(Int64Array::from(value.ints)),
156            Arc::new(Float64Array::from(value.floats)),
157            Arc::new(StringArray::from(value.strings)),
158            Arc::new(StringArray::from(value.arrays)),
159            Arc::new(StringArray::from(value.objects)),
160        ];
161        UnionArray::try_new(union_fields(), Buffer::from_vec(value.type_ids).into(), None, children)
162    }
163}
164
/// A single owned value that can be stored in one row of the JSON union.
#[derive(Debug)]
pub(crate) enum JsonUnionField {
    /// A JSON `null`; carries no payload.
    JsonNull,
    /// A boolean value.
    Bool(bool),
    /// A 64-bit signed integer value.
    Int(i64),
    /// A 64-bit floating point value.
    Float(f64),
    /// A string value.
    Str(String),
    /// An array, held as a string.
    Array(String),
    /// An object, held as a string.
    Object(String),
}
175
// Type ids of the JSON union members. `union_fields()` registers one field per
// id, and `JsonUnionField::type_id` maps each variant to its id.

/// Type id of the "null" member.
pub(crate) const TYPE_ID_NULL: i8 = 0;
/// Type id of the "bool" member.
const TYPE_ID_BOOL: i8 = 1;
/// Type id of the "int" member.
const TYPE_ID_INT: i8 = 2;
/// Type id of the "float" member.
const TYPE_ID_FLOAT: i8 = 3;
/// Type id of the "str" member.
const TYPE_ID_STR: i8 = 4;
/// Type id of the "array" member (nested JSON held as a string).
const TYPE_ID_ARRAY: i8 = 5;
/// Type id of the "object" member (nested JSON held as a string).
const TYPE_ID_OBJECT: i8 = 6;
183
184fn union_fields() -> UnionFields {
185    static FIELDS: OnceLock<UnionFields> = OnceLock::new();
186    FIELDS
187        .get_or_init(|| {
188            UnionFields::from_iter([
189                (TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Null, true))),
190                (TYPE_ID_BOOL, Arc::new(Field::new("bool", DataType::Boolean, false))),
191                (TYPE_ID_INT, Arc::new(Field::new("int", DataType::Int64, false))),
192                (TYPE_ID_FLOAT, Arc::new(Field::new("float", DataType::Float64, false))),
193                (TYPE_ID_STR, Arc::new(Field::new("str", DataType::Utf8, false))),
194                (
195                    TYPE_ID_ARRAY,
196                    Arc::new(Field::new("array", DataType::Utf8, false).with_metadata(json_field_metadata())),
197                ),
198                (
199                    TYPE_ID_OBJECT,
200                    Arc::new(Field::new("object", DataType::Utf8, false).with_metadata(json_field_metadata())),
201                ),
202            ])
203        })
204        .clone()
205}
206
207impl JsonUnionField {
208    fn type_id(&self) -> i8 {
209        match self {
210            Self::JsonNull => TYPE_ID_NULL,
211            Self::Bool(_) => TYPE_ID_BOOL,
212            Self::Int(_) => TYPE_ID_INT,
213            Self::Float(_) => TYPE_ID_FLOAT,
214            Self::Str(_) => TYPE_ID_STR,
215            Self::Array(_) => TYPE_ID_ARRAY,
216            Self::Object(_) => TYPE_ID_OBJECT,
217        }
218    }
219
220    pub fn scalar_value(f: Option<Self>) -> ScalarValue {
221        ScalarValue::Union(
222            f.map(|f| (f.type_id(), Box::new(f.into()))),
223            union_fields(),
224            UnionMode::Sparse,
225        )
226    }
227}
228
229impl From<JsonUnionField> for ScalarValue {
230    fn from(value: JsonUnionField) -> Self {
231        match value {
232            JsonUnionField::JsonNull => Self::Null,
233            JsonUnionField::Bool(b) => Self::Boolean(Some(b)),
234            JsonUnionField::Int(i) => Self::Int64(Some(i)),
235            JsonUnionField::Float(f) => Self::Float64(Some(f)),
236            JsonUnionField::Str(s) | JsonUnionField::Array(s) | JsonUnionField::Object(s) => Self::Utf8(Some(s)),
237        }
238    }
239}
240
241pub struct JsonUnionEncoder {
242    boolean: BooleanArray,
243    int: Int64Array,
244    float: Float64Array,
245    string: StringArray,
246    array: StringArray,
247    object: StringArray,
248    type_ids: ScalarBuffer<i8>,
249}
250
251impl JsonUnionEncoder {
252    #[must_use]
253    pub fn from_union(union: UnionArray) -> Option<Self> {
254        if is_json_union(union.data_type()) {
255            let (_, type_ids, _, c) = union.into_parts();
256            Some(Self {
257                boolean: c[1].as_boolean().clone(),
258                int: c[2].as_primitive().clone(),
259                float: c[3].as_primitive().clone(),
260                string: c[4].as_string().clone(),
261                array: c[5].as_string().clone(),
262                object: c[6].as_string().clone(),
263                type_ids,
264            })
265        } else {
266            None
267        }
268    }
269
270    #[must_use]
271    #[allow(clippy::len_without_is_empty)]
272    pub fn len(&self) -> usize {
273        self.type_ids.len()
274    }
275
276    /// Get the encodable value for a given index
277    ///
278    /// # Panics
279    ///
280    /// Panics if the idx is outside the union values or an invalid type id exists in the union.
281    #[must_use]
282    pub fn get_value(&self, idx: usize) -> JsonUnionValue<'_> {
283        let type_id = self.type_ids[idx];
284        match type_id {
285            TYPE_ID_NULL => JsonUnionValue::JsonNull,
286            TYPE_ID_BOOL => JsonUnionValue::Bool(self.boolean.value(idx)),
287            TYPE_ID_INT => JsonUnionValue::Int(self.int.value(idx)),
288            TYPE_ID_FLOAT => JsonUnionValue::Float(self.float.value(idx)),
289            TYPE_ID_STR => JsonUnionValue::Str(self.string.value(idx)),
290            TYPE_ID_ARRAY => JsonUnionValue::Array(self.array.value(idx)),
291            TYPE_ID_OBJECT => JsonUnionValue::Object(self.object.value(idx)),
292            _ => panic!("Invalid type_id: {type_id}, not a valid JSON type"),
293        }
294    }
295}
296
/// A borrowed value read from one row of a JSON union, as returned by
/// `JsonUnionEncoder::get_value`.
#[derive(Debug, PartialEq)]
pub enum JsonUnionValue<'a> {
    /// The row is tagged as null.
    JsonNull,
    /// The row holds a boolean.
    Bool(bool),
    /// The row holds a 64-bit signed integer.
    Int(i64),
    /// The row holds a 64-bit float.
    Float(f64),
    /// The row holds a string.
    Str(&'a str),
    /// The row holds an array, as a string slice.
    Array(&'a str),
    /// The row holds an object, as a string slice.
    Object(&'a str),
}
307
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trip one value of every kind (plus a missing row) through
    /// `JsonUnion` -> `UnionArray` -> `JsonUnionEncoder`.
    #[test]
    fn test_json_union() {
        let rows = vec![
            Some(JsonUnionField::JsonNull),
            Some(JsonUnionField::Bool(true)),
            Some(JsonUnionField::Bool(false)),
            Some(JsonUnionField::Int(42)),
            Some(JsonUnionField::Float(42.0)),
            Some(JsonUnionField::Str("foo".to_string())),
            Some(JsonUnionField::Array("[42]".to_string())),
            Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())),
            None,
        ];
        // a missing row reads back as a JSON null
        let expected = vec![
            JsonUnionValue::JsonNull,
            JsonUnionValue::Bool(true),
            JsonUnionValue::Bool(false),
            JsonUnionValue::Int(42),
            JsonUnionValue::Float(42.0),
            JsonUnionValue::Str("foo"),
            JsonUnionValue::Array("[42]"),
            JsonUnionValue::Object(r#"{"foo": 42}"#),
            JsonUnionValue::JsonNull,
        ];

        let json_union: JsonUnion = rows.into_iter().collect();
        let union_array = UnionArray::try_from(json_union).unwrap();
        let encoder = JsonUnionEncoder::from_union(union_array).unwrap();

        let mut values_after = Vec::with_capacity(encoder.len());
        for idx in 0..encoder.len() {
            values_after.push(encoder.get_value(idx));
        }
        assert_eq!(values_after, expected);
    }
}
345}