1use std::collections::HashMap;
2use std::sync::{Arc, LazyLock, OnceLock};
3
4use datafusion::arrow::array::{
5 Array, ArrayRef, AsArray, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray,
6};
7use datafusion::arrow::buffer::{Buffer, ScalarBuffer};
8use datafusion::arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
9use datafusion::arrow::error::ArrowError;
10use datafusion::common::ScalarValue;
11
12pub fn is_json_union(data_type: &DataType) -> bool {
13 match data_type {
14 DataType::Union(fields, UnionMode::Sparse) => fields == &union_fields(),
15 _ => false,
16 }
17}
18
19pub(crate) fn nested_json_array(array: &ArrayRef, object_lookup: bool) -> Option<&StringArray> {
26 nested_json_array_ref(array, object_lookup).map(AsArray::as_string)
27}
28
29pub(crate) fn nested_json_array_ref(array: &ArrayRef, object_lookup: bool) -> Option<&ArrayRef> {
30 let union_array: &UnionArray = array.as_any().downcast_ref::<UnionArray>()?;
31 let type_id = if object_lookup { TYPE_ID_OBJECT } else { TYPE_ID_ARRAY };
32 Some(union_array.child(type_id))
33}
34
35pub(crate) fn json_from_union_scalar<'a>(
37 type_id_value: Option<&'a (i8, Box<ScalarValue>)>,
38 fields: &UnionFields,
39) -> Option<&'a str> {
40 if let Some((type_id, value)) = type_id_value {
41 if fields == &union_fields() && (*type_id == TYPE_ID_ARRAY || *type_id == TYPE_ID_OBJECT) {
43 if let ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) = value.as_ref() {
44 return s.as_deref();
45 }
46 }
47 }
48 None
49}
50
51pub static JSON_UNION_DATA_TYPE: LazyLock<DataType> = LazyLock::new(JsonUnion::data_type);
52
53#[derive(Debug)]
54pub(crate) struct JsonUnion {
55 bools: Vec<Option<bool>>,
56 ints: Vec<Option<i64>>,
57 floats: Vec<Option<f64>>,
58 strings: Vec<Option<String>>,
59 arrays: Vec<Option<String>>,
60 objects: Vec<Option<String>>,
61 type_ids: Vec<i8>,
62 index: usize,
63 length: usize,
64}
65
66impl JsonUnion {
67 pub fn new(length: usize) -> Self {
68 Self {
69 bools: vec![None; length],
70 ints: vec![None; length],
71 floats: vec![None; length],
72 strings: vec![None; length],
73 arrays: vec![None; length],
74 objects: vec![None; length],
75 type_ids: vec![TYPE_ID_NULL; length],
76 index: 0,
77 length,
78 }
79 }
80
81 pub fn data_type() -> DataType {
82 DataType::Union(union_fields(), UnionMode::Sparse)
83 }
84
85 pub fn push(&mut self, field: JsonUnionField) {
86 self.type_ids[self.index] = field.type_id();
87 match field {
88 JsonUnionField::JsonNull => (),
89 JsonUnionField::Bool(value) => self.bools[self.index] = Some(value),
90 JsonUnionField::Int(value) => self.ints[self.index] = Some(value),
91 JsonUnionField::Float(value) => self.floats[self.index] = Some(value),
92 JsonUnionField::Str(value) => self.strings[self.index] = Some(value),
93 JsonUnionField::Array(value) => self.arrays[self.index] = Some(value),
94 JsonUnionField::Object(value) => self.objects[self.index] = Some(value),
95 }
96 self.index += 1;
97 debug_assert!(self.index <= self.length);
98 }
99
100 pub fn push_none(&mut self) {
101 self.index += 1;
102 debug_assert!(self.index <= self.length);
103 }
104}
105
106impl FromIterator<Option<JsonUnionField>> for JsonUnion {
108 fn from_iter<I: IntoIterator<Item = Option<JsonUnionField>>>(iter: I) -> Self {
109 let inner = iter.into_iter();
110 let (lower, upper) = inner.size_hint();
111 let mut union = Self::new(upper.unwrap_or(lower));
112
113 for opt_field in inner {
114 if let Some(union_field) = opt_field {
115 union.push(union_field);
116 } else {
117 union.push_none();
118 }
119 }
120 union
121 }
122}
123
124impl TryFrom<JsonUnion> for UnionArray {
125 type Error = ArrowError;
126
127 fn try_from(value: JsonUnion) -> Result<Self, Self::Error> {
128 let children: Vec<Arc<dyn Array>> = vec![
129 Arc::new(NullArray::new(value.length)),
130 Arc::new(BooleanArray::from(value.bools)),
131 Arc::new(Int64Array::from(value.ints)),
132 Arc::new(Float64Array::from(value.floats)),
133 Arc::new(StringArray::from(value.strings)),
134 Arc::new(StringArray::from(value.arrays)),
135 Arc::new(StringArray::from(value.objects)),
136 ];
137 UnionArray::try_new(union_fields(), Buffer::from_vec(value.type_ids).into(), None, children)
138 }
139}
140
/// One owned JSON value to be written into a [`JsonUnion`] row.
///
/// Each variant maps to one union member via `JsonUnionField::type_id`. `Array` and `Object`
/// carry the nested value as JSON text (e.g. `"[42]"` or `{"foo": 42}`), stored in a `Utf8` member.
#[derive(Debug)]
pub(crate) enum JsonUnionField {
    /// JSON `null`.
    JsonNull,
    /// JSON boolean.
    Bool(bool),
    /// JSON number stored as a 64-bit integer.
    Int(i64),
    /// JSON number stored as a 64-bit float.
    Float(f64),
    /// JSON string value.
    Str(String),
    /// Nested JSON array, as JSON text.
    Array(String),
    /// Nested JSON object, as JSON text.
    Object(String),
}
151
// Type ids of the JSON union members. These values are also the child positions: `union_fields`
// declares the members in this order, and both `TryFrom<JsonUnion> for UnionArray` and
// `JsonUnionEncoder::from_union` build/read the children in the same order.

/// Type id of the JSON `null` member (`DataType::Null`).
pub(crate) const TYPE_ID_NULL: i8 = 0;
/// Type id of the `bool` member (`DataType::Boolean`).
const TYPE_ID_BOOL: i8 = 1;
/// Type id of the `int` member (`DataType::Int64`).
const TYPE_ID_INT: i8 = 2;
/// Type id of the `float` member (`DataType::Float64`).
const TYPE_ID_FLOAT: i8 = 3;
/// Type id of the `str` member (`DataType::Utf8`).
const TYPE_ID_STR: i8 = 4;
/// Type id of the `array` member (`DataType::Utf8` holding JSON text).
const TYPE_ID_ARRAY: i8 = 5;
/// Type id of the `object` member (`DataType::Utf8` holding JSON text).
const TYPE_ID_OBJECT: i8 = 6;
159
160fn union_fields() -> UnionFields {
161 static FIELDS: OnceLock<UnionFields> = OnceLock::new();
162 FIELDS
163 .get_or_init(|| {
164 let json_metadata: HashMap<String, String> =
165 HashMap::from_iter(vec![("is_json".to_string(), "true".to_string())]);
166 UnionFields::from_iter([
167 (TYPE_ID_NULL, Arc::new(Field::new("null", DataType::Null, true))),
168 (TYPE_ID_BOOL, Arc::new(Field::new("bool", DataType::Boolean, false))),
169 (TYPE_ID_INT, Arc::new(Field::new("int", DataType::Int64, false))),
170 (TYPE_ID_FLOAT, Arc::new(Field::new("float", DataType::Float64, false))),
171 (TYPE_ID_STR, Arc::new(Field::new("str", DataType::Utf8, false))),
172 (
173 TYPE_ID_ARRAY,
174 Arc::new(Field::new("array", DataType::Utf8, false).with_metadata(json_metadata.clone())),
175 ),
176 (
177 TYPE_ID_OBJECT,
178 Arc::new(Field::new("object", DataType::Utf8, false).with_metadata(json_metadata.clone())),
179 ),
180 ])
181 })
182 .clone()
183}
184
185impl JsonUnionField {
186 fn type_id(&self) -> i8 {
187 match self {
188 Self::JsonNull => TYPE_ID_NULL,
189 Self::Bool(_) => TYPE_ID_BOOL,
190 Self::Int(_) => TYPE_ID_INT,
191 Self::Float(_) => TYPE_ID_FLOAT,
192 Self::Str(_) => TYPE_ID_STR,
193 Self::Array(_) => TYPE_ID_ARRAY,
194 Self::Object(_) => TYPE_ID_OBJECT,
195 }
196 }
197
198 pub fn scalar_value(f: Option<Self>) -> ScalarValue {
199 ScalarValue::Union(
200 f.map(|f| (f.type_id(), Box::new(f.into()))),
201 union_fields(),
202 UnionMode::Sparse,
203 )
204 }
205}
206
207impl From<JsonUnionField> for ScalarValue {
208 fn from(value: JsonUnionField) -> Self {
209 match value {
210 JsonUnionField::JsonNull => Self::Null,
211 JsonUnionField::Bool(b) => Self::Boolean(Some(b)),
212 JsonUnionField::Int(i) => Self::Int64(Some(i)),
213 JsonUnionField::Float(f) => Self::Float64(Some(f)),
214 JsonUnionField::Str(s) | JsonUnionField::Array(s) | JsonUnionField::Object(s) => Self::Utf8(Some(s)),
215 }
216 }
217}
218
219pub struct JsonUnionEncoder {
220 boolean: BooleanArray,
221 int: Int64Array,
222 float: Float64Array,
223 string: StringArray,
224 array: StringArray,
225 object: StringArray,
226 type_ids: ScalarBuffer<i8>,
227}
228
229impl JsonUnionEncoder {
230 #[must_use]
231 pub fn from_union(union: UnionArray) -> Option<Self> {
232 if is_json_union(union.data_type()) {
233 let (_, type_ids, _, c) = union.into_parts();
234 Some(Self {
235 boolean: c[1].as_boolean().clone(),
236 int: c[2].as_primitive().clone(),
237 float: c[3].as_primitive().clone(),
238 string: c[4].as_string().clone(),
239 array: c[5].as_string().clone(),
240 object: c[6].as_string().clone(),
241 type_ids,
242 })
243 } else {
244 None
245 }
246 }
247
248 #[must_use]
249 #[allow(clippy::len_without_is_empty)]
250 pub fn len(&self) -> usize {
251 self.type_ids.len()
252 }
253
254 #[must_use]
260 pub fn get_value(&self, idx: usize) -> JsonUnionValue<'_> {
261 let type_id = self.type_ids[idx];
262 match type_id {
263 TYPE_ID_NULL => JsonUnionValue::JsonNull,
264 TYPE_ID_BOOL => JsonUnionValue::Bool(self.boolean.value(idx)),
265 TYPE_ID_INT => JsonUnionValue::Int(self.int.value(idx)),
266 TYPE_ID_FLOAT => JsonUnionValue::Float(self.float.value(idx)),
267 TYPE_ID_STR => JsonUnionValue::Str(self.string.value(idx)),
268 TYPE_ID_ARRAY => JsonUnionValue::Array(self.array.value(idx)),
269 TYPE_ID_OBJECT => JsonUnionValue::Object(self.object.value(idx)),
270 _ => panic!("Invalid type_id: {type_id}, not a valid JSON type"),
271 }
272 }
273}
274
/// A borrowed value read from one row of a JSON union by [`JsonUnionEncoder::get_value`].
///
/// String-carrying variants borrow from the encoder's member arrays.
#[derive(Debug, PartialEq)]
pub enum JsonUnionValue<'a> {
    /// Row tagged with the `null` member.
    JsonNull,
    /// Value of the `bool` member.
    Bool(bool),
    /// Value of the `int` member.
    Int(i64),
    /// Value of the `float` member.
    Float(f64),
    /// Value of the `str` member.
    Str(&'a str),
    /// JSON text of the `array` member.
    Array(&'a str),
    /// JSON text of the `object` member.
    Object(&'a str),
}
285
#[cfg(test)]
mod test {
    use super::*;

    /// Round-trips every JSON member through `JsonUnion` -> `UnionArray` -> `JsonUnionEncoder`.
    #[test]
    fn test_json_union() {
        // Each input row paired with the value it must decode back to.
        let rows = vec![
            (Some(JsonUnionField::JsonNull), JsonUnionValue::JsonNull),
            (Some(JsonUnionField::Bool(true)), JsonUnionValue::Bool(true)),
            (Some(JsonUnionField::Bool(false)), JsonUnionValue::Bool(false)),
            (Some(JsonUnionField::Int(42)), JsonUnionValue::Int(42)),
            (Some(JsonUnionField::Float(42.0)), JsonUnionValue::Float(42.0)),
            (
                Some(JsonUnionField::Str("foo".to_string())),
                JsonUnionValue::Str("foo"),
            ),
            (
                Some(JsonUnionField::Array("[42]".to_string())),
                JsonUnionValue::Array("[42]"),
            ),
            (
                Some(JsonUnionField::Object(r#"{"foo": 42}"#.to_string())),
                JsonUnionValue::Object(r#"{"foo": 42}"#),
            ),
            // A missing row decodes as JSON null.
            (None, JsonUnionValue::JsonNull),
        ];
        let (inputs, expected): (Vec<_>, Vec<_>) = rows.into_iter().unzip();

        let union_array = UnionArray::try_from(JsonUnion::from_iter(inputs)).unwrap();
        let encoder = JsonUnionEncoder::from_union(union_array).unwrap();
        let decoded: Vec<_> = (0..encoder.len()).map(|idx| encoder.get_value(idx)).collect();

        assert_eq!(decoded, expected);
    }
}