arrow_array/array/
union_array.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17#![allow(clippy::enum_clike_unportable_variant)]
18
19use crate::{make_array, Array, ArrayRef};
20use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks};
21use arrow_buffer::buffer::NullBuffer;
22use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer};
23use arrow_data::{ArrayData, ArrayDataBuilder};
24use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode};
25/// Contains the `UnionArray` type.
26///
27use std::any::Any;
28use std::collections::HashSet;
29use std::sync::Arc;
30
31/// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout)
32///
33/// Each slot in a [UnionArray] can have a value chosen from a number
34/// of types.  Each of the possible types are named like the fields of
35/// a [`StructArray`](crate::StructArray).  A `UnionArray` can
36/// have two possible memory layouts, "dense" or "sparse".  For more
/// information please see the
38/// [specification](https://arrow.apache.org/docs/format/Columnar.html#union-layout).
39///
40/// [UnionBuilder](crate::builder::UnionBuilder) can be used to
41/// create [UnionArray]'s of primitive types. `UnionArray`'s of nested
42/// types are also supported but not via `UnionBuilder`, see the tests
43/// for examples.
44///
45/// # Examples
46/// ## Create a dense UnionArray `[1, 3.2, 34]`
47/// ```
48/// use arrow_buffer::ScalarBuffer;
49/// use arrow_schema::*;
50/// use std::sync::Arc;
51/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray};
52///
53/// let int_array = Int32Array::from(vec![1, 34]);
54/// let float_array = Float64Array::from(vec![3.2]);
55/// let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
56/// let offsets = [0, 0, 1].into_iter().collect::<ScalarBuffer<i32>>();
57///
58/// let union_fields = [
59///     (0, Arc::new(Field::new("A", DataType::Int32, false))),
60///     (1, Arc::new(Field::new("B", DataType::Float64, false))),
61/// ].into_iter().collect::<UnionFields>();
62///
63/// let children = vec![
64///     Arc::new(int_array) as Arc<dyn Array>,
65///     Arc::new(float_array),
66/// ];
67///
68/// let array = UnionArray::try_new(
69///     union_fields,
70///     type_ids,
71///     Some(offsets),
72///     children,
73/// ).unwrap();
74///
75/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
76/// assert_eq!(1, value);
77///
78/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0);
79/// assert!(3.2 - value < f64::EPSILON);
80///
81/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
82/// assert_eq!(34, value);
83/// ```
84///
85/// ## Create a sparse UnionArray `[1, 3.2, 34]`
86/// ```
87/// use arrow_buffer::ScalarBuffer;
88/// use arrow_schema::*;
89/// use std::sync::Arc;
90/// use arrow_array::{Array, Int32Array, Float64Array, UnionArray};
91///
92/// let int_array = Int32Array::from(vec![Some(1), None, Some(34)]);
93/// let float_array = Float64Array::from(vec![None, Some(3.2), None]);
94/// let type_ids = [0_i8, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
95///
96/// let union_fields = [
97///     (0, Arc::new(Field::new("A", DataType::Int32, false))),
98///     (1, Arc::new(Field::new("B", DataType::Float64, false))),
99/// ].into_iter().collect::<UnionFields>();
100///
101/// let children = vec![
102///     Arc::new(int_array) as Arc<dyn Array>,
103///     Arc::new(float_array),
104/// ];
105///
106/// let array = UnionArray::try_new(
107///     union_fields,
108///     type_ids,
109///     None,
110///     children,
111/// ).unwrap();
112///
113/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
114/// assert_eq!(1, value);
115///
116/// let value = array.value(1).as_any().downcast_ref::<Float64Array>().unwrap().value(0);
117/// assert!(3.2 - value < f64::EPSILON);
118///
119/// let value = array.value(2).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
120/// assert_eq!(34, value);
121/// ```
#[derive(Clone)]
pub struct UnionArray {
    /// The data type of this array; expected to be a `DataType::Union`.
    data_type: DataType,
    /// One type id per slot, selecting which child array holds that slot's value.
    type_ids: ScalarBuffer<i8>,
    /// One offset per slot into the selected child array.
    /// `Some` for dense unions, `None` for sparse unions.
    offsets: Option<ScalarBuffer<i32>>,
    /// Child arrays indexed by type id; entries for type ids that have no
    /// corresponding field are `None`.
    fields: Vec<Option<ArrayRef>>,
}
129
130impl UnionArray {
131    /// Creates a new `UnionArray`.
132    ///
133    /// Accepts type ids, child arrays and optionally offsets (for dense unions) to create
134    /// a new `UnionArray`.  This method makes no attempt to validate the data provided by the
135    /// caller and assumes that each of the components are correct and consistent with each other.
136    /// See `try_new` for an alternative that validates the data provided.
137    ///
138    /// # Safety
139    ///
140    /// The `type_ids` values should be positive and must match one of the type ids of the fields provided in `fields`.
141    /// These values are used to index into the `children` arrays.
142    ///
143    /// The `offsets` is provided in the case of a dense union, sparse unions should use `None`.
144    /// If provided the `offsets` values should be positive and must be less than the length of the
145    /// corresponding array.
146    ///
147    /// In both cases above we use signed integer types to maintain compatibility with other
148    /// Arrow implementations.
149    pub unsafe fn new_unchecked(
150        fields: UnionFields,
151        type_ids: ScalarBuffer<i8>,
152        offsets: Option<ScalarBuffer<i32>>,
153        children: Vec<ArrayRef>,
154    ) -> Self {
155        let mode = if offsets.is_some() {
156            UnionMode::Dense
157        } else {
158            UnionMode::Sparse
159        };
160
161        let len = type_ids.len();
162        let builder = ArrayData::builder(DataType::Union(fields, mode))
163            .add_buffer(type_ids.into_inner())
164            .child_data(children.into_iter().map(Array::into_data).collect())
165            .len(len);
166
167        let data = match offsets {
168            Some(offsets) => builder.add_buffer(offsets.into_inner()).build_unchecked(),
169            None => builder.build_unchecked(),
170        };
171        Self::from(data)
172    }
173
174    /// Attempts to create a new `UnionArray`, validating the inputs provided.
175    ///
176    /// The order of child arrays child array order must match the fields order
177    pub fn try_new(
178        fields: UnionFields,
179        type_ids: ScalarBuffer<i8>,
180        offsets: Option<ScalarBuffer<i32>>,
181        children: Vec<ArrayRef>,
182    ) -> Result<Self, ArrowError> {
183        // There must be a child array for every field.
184        if fields.len() != children.len() {
185            return Err(ArrowError::InvalidArgumentError(
186                "Union fields length must match child arrays length".to_string(),
187            ));
188        }
189
190        if let Some(offsets) = &offsets {
191            // There must be an offset value for every type id value.
192            if offsets.len() != type_ids.len() {
193                return Err(ArrowError::InvalidArgumentError(
194                    "Type Ids and Offsets lengths must match".to_string(),
195                ));
196            }
197        } else {
198            // Sparse union child arrays must be equal in length to the length of the union
199            for child in &children {
200                if child.len() != type_ids.len() {
201                    return Err(ArrowError::InvalidArgumentError(
202                        "Sparse union child arrays must be equal in length to the length of the union".to_string(),
203                    ));
204                }
205            }
206        }
207
208        // Create mapping from type id to array lengths.
209        let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
210        let mut array_lens = vec![i32::MIN; max_id + 1];
211        for (cd, (field_id, _)) in children.iter().zip(fields.iter()) {
212            array_lens[field_id as usize] = cd.len() as i32;
213        }
214
215        // Type id values must match one of the fields.
216        for id in &type_ids {
217            match array_lens.get(*id as usize) {
218                Some(x) if *x != i32::MIN => {}
219                _ => {
220                    return Err(ArrowError::InvalidArgumentError(
221                        "Type Ids values must match one of the field type ids".to_owned(),
222                    ))
223                }
224            }
225        }
226
227        // Check the value offsets are in bounds.
228        if let Some(offsets) = &offsets {
229            let mut iter = type_ids.iter().zip(offsets.iter());
230            if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize])
231            {
232                return Err(ArrowError::InvalidArgumentError(
233                    "Offsets must be positive and within the length of the Array".to_owned(),
234                ));
235            }
236        }
237
238        // Safety:
239        // - Arguments validated above.
240        let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) };
241        Ok(union_array)
242    }
243
244    /// Accesses the child array for `type_id`.
245    ///
246    /// # Panics
247    ///
248    /// Panics if the `type_id` provided is not present in the array's DataType
249    /// in the `Union`.
250    pub fn child(&self, type_id: i8) -> &ArrayRef {
251        assert!((type_id as usize) < self.fields.len());
252        let boxed = &self.fields[type_id as usize];
253        boxed.as_ref().expect("invalid type id")
254    }
255
256    /// Returns the `type_id` for the array slot at `index`.
257    ///
258    /// # Panics
259    ///
260    /// Panics if `index` is greater than or equal to the number of child arrays
261    pub fn type_id(&self, index: usize) -> i8 {
262        assert!(index < self.type_ids.len());
263        self.type_ids[index]
264    }
265
    /// Returns the `type_ids` buffer for this array.
    ///
    /// The buffer holds one type id per slot; its length is the length of the array.
    pub fn type_ids(&self) -> &ScalarBuffer<i8> {
        &self.type_ids
    }
270
    /// Returns the `offsets` buffer if this is a dense array, or `None` if no
    /// offsets buffer is stored (sparse array).
    pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> {
        self.offsets.as_ref()
    }
275
276    /// Returns the offset into the underlying values array for the array slot at `index`.
277    ///
278    /// # Panics
279    ///
280    /// Panics if `index` is greater than or equal the length of the array.
281    pub fn value_offset(&self, index: usize) -> usize {
282        assert!(index < self.len());
283        match &self.offsets {
284            Some(offsets) => offsets[index] as usize,
285            None => self.offset() + index,
286        }
287    }
288
289    /// Returns the array's value at index `i`.
290    ///
291    /// Note: This method does not check for nulls and the value is arbitrary
292    /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index.
293    ///
294    /// # Panics
295    /// Panics if index `i` is out of bounds
296    pub fn value(&self, i: usize) -> ArrayRef {
297        let type_id = self.type_id(i);
298        let value_offset = self.value_offset(i);
299        let child = self.child(type_id);
300        child.slice(value_offset, 1)
301    }
302
303    /// Returns the names of the types in the union.
304    pub fn type_names(&self) -> Vec<&str> {
305        match self.data_type() {
306            DataType::Union(fields, _) => fields
307                .iter()
308                .map(|(_, f)| f.name().as_str())
309                .collect::<Vec<&str>>(),
310            _ => unreachable!("Union array's data type is not a union!"),
311        }
312    }
313
314    /// Returns whether the `UnionArray` is dense (or sparse if `false`).
315    fn is_dense(&self) -> bool {
316        match self.data_type() {
317            DataType::Union(_, mode) => mode == &UnionMode::Dense,
318            _ => unreachable!("Union array's data type is not a union!"),
319        }
320    }
321
322    /// Returns a zero-copy slice of this array with the indicated offset and length.
323    pub fn slice(&self, offset: usize, length: usize) -> Self {
324        let (offsets, fields) = match self.offsets.as_ref() {
325            // If dense union, slice offsets
326            Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()),
327            // Otherwise need to slice sparse children
328            None => {
329                let fields = self
330                    .fields
331                    .iter()
332                    .map(|x| x.as_ref().map(|x| x.slice(offset, length)))
333                    .collect();
334                (None, fields)
335            }
336        };
337
338        Self {
339            data_type: self.data_type.clone(),
340            type_ids: self.type_ids.slice(offset, length),
341            offsets,
342            fields,
343        }
344    }
345
346    /// Deconstruct this array into its constituent parts
347    ///
348    /// # Example
349    ///
350    /// ```
351    /// # use arrow_array::array::UnionArray;
352    /// # use arrow_array::types::Int32Type;
353    /// # use arrow_array::builder::UnionBuilder;
354    /// # use arrow_buffer::ScalarBuffer;
355    /// # fn main() -> Result<(), arrow_schema::ArrowError> {
356    /// let mut builder = UnionBuilder::new_dense();
357    /// builder.append::<Int32Type>("a", 1).unwrap();
358    /// let union_array = builder.build()?;
359    ///
360    /// // Deconstruct into parts
361    /// let (union_fields, type_ids, offsets, children) = union_array.into_parts();
362    ///
363    /// // Reconstruct from parts
364    /// let union_array = UnionArray::try_new(
365    ///     union_fields,
366    ///     type_ids,
367    ///     offsets,
368    ///     children,
369    /// );
370    /// # Ok(())
371    /// # }
372    /// ```
373    #[allow(clippy::type_complexity)]
374    pub fn into_parts(
375        self,
376    ) -> (
377        UnionFields,
378        ScalarBuffer<i8>,
379        Option<ScalarBuffer<i32>>,
380        Vec<ArrayRef>,
381    ) {
382        let Self {
383            data_type,
384            type_ids,
385            offsets,
386            mut fields,
387        } = self;
388        match data_type {
389            DataType::Union(union_fields, _) => {
390                let children = union_fields
391                    .iter()
392                    .map(|(type_id, _)| fields[type_id as usize].take().unwrap())
393                    .collect();
394                (union_fields, type_ids, offsets, children)
395            }
396            _ => unreachable!(),
397        }
398    }
399
    /// Computes the logical nulls for a sparse union, optimized for when there are a lot of fields without nulls.
    ///
    /// `nulls` holds `(type_id, null buffer)` pairs for the fields that contain nulls.
    /// Slots whose type id is not listed in `nulls` are treated as valid (bit set).
    /// The work is done 64 slots at a time through [`Self::mask_sparse_helper`].
    fn mask_sparse_skip_without_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
        // Example logic for a union with 5 fields, a, b & c with nulls, d & e without nulls:
        // let [a_nulls, b_nulls, c_nulls] = nulls;
        // let [is_a, is_b, is_c] = masks;
        // let is_d_or_e = !(is_a | is_b | is_c)
        // let union_chunk_nulls = is_d_or_e  | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)

        // Accumulates, over the fields with nulls:
        // - which slots select any of those fields (`with_nulls_selected`), and
        // - the validity bits of the slots that select them (`union_nulls`).
        let fold = |(with_nulls_selected, union_nulls), (is_field, field_nulls)| {
            (
                with_nulls_selected | is_field,
                union_nulls | (is_field & field_nulls),
            )
        };

        self.mask_sparse_helper(
            nulls,
            // Full 64-slot chunk: pull the next 64 validity bits of every field.
            |type_ids_chunk_array, nulls_masks_iters| {
                let (with_nulls_selected, union_nulls) = nulls_masks_iters
                    .iter_mut()
                    .map(|(field_type_id, field_nulls)| {
                        let field_nulls = field_nulls.next().unwrap();
                        // `selection_mask` is defined elsewhere in this file; presumably it sets
                        // bit `i` when `type_ids_chunk_array[i] == field_type_id` — confirm there.
                        let is_field = selection_mask(type_ids_chunk_array, *field_type_id);

                        (is_field, field_nulls)
                    })
                    .fold((0, 0), fold);

                // In the example above, this is the is_d_or_e = !(is_a | is_b | is_c) part
                let without_nulls_selected = !with_nulls_selected;

                // if a field without nulls is selected, the value is always true(set bit)
                // otherwise, the true/set bits have been computed above
                without_nulls_selected | union_nulls
            },
            // Trailing partial chunk: same logic using each field's remainder bits.
            |type_ids_remainder, bit_chunks| {
                let (with_nulls_selected, union_nulls) = bit_chunks
                    .iter()
                    .map(|(field_type_id, field_bit_chunks)| {
                        let field_nulls = field_bit_chunks.remainder_bits();
                        let is_field = selection_mask(type_ids_remainder, *field_type_id);

                        (is_field, field_nulls)
                    })
                    .fold((0, 0), fold);

                // NOTE(review): the negation may set bits past the remainder length; the helper
                // builds the result with length `self.type_ids.len()`, which presumably makes
                // those trailing bits irrelevant — confirm.
                let without_nulls_selected = !with_nulls_selected;

                without_nulls_selected | union_nulls
            },
        )
    }
451
    /// Computes the logical nulls for a sparse union, optimized for when there are a lot of fully null fields.
    ///
    /// `nulls` holds `(type_id, null buffer)` pairs for the fields that contain nulls.
    /// Fully null fields are dropped from the per-chunk work: they can never contribute
    /// a set (valid) bit. Fields absent from `nulls` are handled through an explicit
    /// list of "without nulls" type ids.
    fn mask_sparse_skip_fully_null(&self, mut nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
        let fields = match self.data_type() {
            DataType::Union(fields, _) => fields,
            _ => unreachable!("Union array's data type is not a union!"),
        };

        // Type ids of the fields that are NOT in `nulls`, i.e. the fields without nulls.
        // This must be computed before fully null fields are removed from `nulls` below.
        let type_ids = fields.iter().map(|(id, _)| id).collect::<HashSet<_>>();
        let with_nulls = nulls.iter().map(|(id, _)| *id).collect::<HashSet<_>>();

        let without_nulls_ids = type_ids
            .difference(&with_nulls)
            .copied()
            .collect::<Vec<_>>();

        // Keep only the fields that have at least one valid value.
        nulls.retain(|(_, nulls)| nulls.null_count() < nulls.len());

        // Example logic for a union with 6 fields, a, b & c with nulls, d & e without nulls, and f fully_null:
        // let [a_nulls, b_nulls, c_nulls] = nulls;
        // let [is_a, is_b, is_c, is_d, is_e] = masks;
        // let union_chunk_nulls = is_d | is_e | (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
        self.mask_sparse_helper(
            nulls,
            // Full 64-slot chunk.
            |type_ids_chunk_array, nulls_masks_iters| {
                let union_nulls = nulls_masks_iters.iter_mut().fold(
                    0,
                    |union_nulls, (field_type_id, nulls_iter)| {
                        let field_nulls = nulls_iter.next().unwrap();

                        // All 64 values of this field are null in this chunk:
                        // `is_field & 0` would be 0, so skip computing the selection mask.
                        if field_nulls == 0 {
                            union_nulls
                        } else {
                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);

                            union_nulls | (is_field & field_nulls)
                        }
                    },
                );

                // Given the example above, this is the is_d_or_e = (is_d | is_e) part
                // (`without_nulls_selected` is a helper defined elsewhere in this file).
                let without_nulls_selected =
                    without_nulls_selected(type_ids_chunk_array, &without_nulls_ids);

                // if a field without nulls is selected, the value is always true(set bit)
                // otherwise, the true/set bits have been computed above
                union_nulls | without_nulls_selected
            },
            // Trailing partial chunk: same logic using each field's remainder bits.
            |type_ids_remainder, bit_chunks| {
                let union_nulls =
                    bit_chunks
                        .iter()
                        .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
                            let is_field = selection_mask(type_ids_remainder, *field_type_id);
                            let field_nulls = field_bit_chunks.remainder_bits();

                            // `&` binds tighter than `|`: this is union_nulls | (is_field & field_nulls)
                            union_nulls | is_field & field_nulls
                        });

                union_nulls | without_nulls_selected(type_ids_remainder, &without_nulls_ids)
            },
        )
    }
514
    /// Computes the logical nulls for a sparse union, optimized for when all fields contain nulls.
    ///
    /// `nulls` holds `(type_id, null buffer)` pairs. For full chunks the selection mask of
    /// the first entry is not computed; it is derived as the negation of all other masks.
    ///
    /// NOTE(review): that derivation is only correct if every type id in the array belongs
    /// to one of the entries of `nulls`, and the indexing below panics on an empty `nulls` —
    /// presumably the caller guarantees both; confirm at the call site.
    fn mask_sparse_all_with_nulls_skip_one(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
        // Example logic for a union with 3 fields, a, b & c, all containing nulls:
        // let [a_nulls, b_nulls, c_nulls] = nulls;
        // We can skip the first field: its selection mask is the negation of all others' selection masks
        // let [is_b, is_c] = selection_masks;
        // let is_a = !(is_b | is_c)
        // let union_chunk_nulls = (is_a & a_nulls) | (is_b & b_nulls) | (is_c & c_nulls)
        self.mask_sparse_helper(
            nulls,
            // Full 64-slot chunk.
            |type_ids_chunk_array, nulls_masks_iters| {
                let (is_not_first, union_nulls) = nulls_masks_iters[1..] // skip first
                    .iter_mut()
                    .fold(
                        (0, 0),
                        |(is_not_first, union_nulls), (field_type_id, nulls_iter)| {
                            let field_nulls = nulls_iter.next().unwrap();
                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);

                            (
                                is_not_first | is_field,
                                union_nulls | (is_field & field_nulls),
                            )
                        },
                    );

                // A slot selects the first field exactly when it selects none of the others.
                let is_first = !is_not_first;
                // The first field's iterator must still be advanced once per chunk.
                let first_nulls = nulls_masks_iters[0].1.next().unwrap();

                (is_first & first_nulls) | union_nulls
            },
            // Trailing partial chunk.
            |type_ids_remainder, bit_chunks| {
                bit_chunks
                    .iter()
                    .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
                        let field_nulls = field_bit_chunks.remainder_bits();
                        // The same logic as above, except that since this runs at most once,
                        // it makes no difference to speed up the first selection mask
                        let is_field = selection_mask(type_ids_remainder, *field_type_id);

                        union_nulls | (is_field & field_nulls)
                    })
            },
        )
    }
560
    /// Maps `nulls` to `BitChunks` and then to `BitChunkIterator`s, then divides `self.type_ids` into exact chunks of 64 values,
    /// calling `mask_chunk` for every exact chunk, and `mask_remainder` for the remainder, if any, collecting the result in a `BooleanBuffer`.
    ///
    /// - `nulls`: `(type_id, null buffer)` pairs handed to the callbacks.
    /// - `mask_chunk`: receives 64 type ids plus the per-field chunk iterators and returns
    ///   the 64 result bits as a `u64`.
    /// - `mask_remainder`: receives the leftover type ids (fewer than 64) plus the per-field
    ///   `BitChunks` and returns the result bits as a `u64`.
    ///
    /// The returned buffer has length `self.type_ids.len()`.
    fn mask_sparse_helper(
        &self,
        nulls: Vec<(i8, NullBuffer)>,
        mut mask_chunk: impl FnMut(&[i8; 64], &mut [(i8, BitChunkIterator)]) -> u64,
        mask_remainder: impl FnOnce(&[i8], &[(i8, BitChunks)]) -> u64,
    ) -> BooleanBuffer {
        // Kept separately from the iterators below because `mask_remainder` needs
        // the `BitChunks` themselves to read the remainder bits.
        let bit_chunks = nulls
            .iter()
            .map(|(type_id, nulls)| (*type_id, nulls.inner().bit_chunks()))
            .collect::<Vec<_>>();

        // One 64-bit chunk iterator per field; `mask_chunk` advances them as it goes.
        let mut nulls_masks_iter = bit_chunks
            .iter()
            .map(|(type_id, bit_chunks)| (*type_id, bit_chunks.iter()))
            .collect::<Vec<_>>();

        let chunks_exact = self.type_ids.chunks_exact(64);
        let remainder = chunks_exact.remainder();

        let chunks = chunks_exact.map(|type_ids_chunk| {
            // Cannot fail: `chunks_exact(64)` only yields slices of length 64.
            let type_ids_chunk_array = <&[i8; 64]>::try_from(type_ids_chunk).unwrap();

            mask_chunk(type_ids_chunk_array, &mut nulls_masks_iter)
        });

        // SAFETY:
        // `chunks` maps over a ChunksExact iterator, which implements TrustedLen, and correctly reports its length
        let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) };

        // The last, partial word of bits (if the length is not a multiple of 64).
        if !remainder.is_empty() {
            buffer.push(mask_remainder(remainder, &bit_chunks));
        }

        BooleanBuffer::new(buffer.into(), 0, self.type_ids.len())
    }
598
    /// Computes the logical nulls for a sparse or dense union, by gathering individual bits from the null buffer of the selected field.
    ///
    /// `nulls` holds `(type_id, null buffer)` pairs. Type ids that are not listed are
    /// treated as always valid.
    fn gather_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
        // Length-1 stand-ins used for fully null fields and for fields without an entry.
        let one_null = NullBuffer::new_null(1);
        let one_valid = NullBuffer::new_valid(1);

        // The unsafe code below depends on this:
        // To remove one branch from the loop, if a type_id is not utilized, or its logical_nulls is None/all set,
        // we use a null buffer of len 1 and an index_mask of 0, or the true null buffer and usize::MAX otherwise.
        // We then unconditionally access the null buffer with index & index_mask,
        // which always returns 0 for the 1-len buffer, or the true index unchanged otherwise
        // We also use a 256 array, so llvm knows that `type_id as u8 as usize` is always in bounds
        // (`Mask` is defined elsewhere in this file.)
        let mut logical_nulls_array = [(&one_valid, Mask::Zero); 256];

        for (type_id, nulls) in &nulls {
            if nulls.null_count() == nulls.len() {
                // Similarly, if all values are null, use a 1-null null-buffer to reduce cache pressure a bit
                logical_nulls_array[*type_id as u8 as usize] = (&one_null, Mask::Zero);
            } else {
                logical_nulls_array[*type_id as u8 as usize] = (nulls, Mask::Max);
            }
        }

        match &self.offsets {
            // Dense: the bit to read is at the slot's offset into the selected child.
            Some(offsets) => {
                assert_eq!(self.type_ids.len(), offsets.len());

                BooleanBuffer::collect_bool(self.type_ids.len(), |i| unsafe {
                    // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len()
                    let type_id = *self.type_ids.get_unchecked(i);
                    // SAFETY: We asserted that offsets len and self.type_ids len are equal
                    let offset = *offsets.get_unchecked(i);

                    let (nulls, offset_mask) = &logical_nulls_array[type_id as u8 as usize];

                    // SAFETY:
                    // If offset_mask is Max
                    // 1. Offset validity is checked at union creation
                    // 2. That the null buffer len equals its array len is checked at array creation
                    // If offset_mask is Zero, the null buffer len is 1
                    nulls
                        .inner()
                        .value_unchecked(offset as usize & *offset_mask as usize)
                })
            }
            // Sparse: the bit to read is at the slot's own index in the selected child.
            None => {
                BooleanBuffer::collect_bool(self.type_ids.len(), |index| unsafe {
                    // SAFETY: BooleanBuffer::collect_bool calls us 0..self.type_ids.len()
                    let type_id = *self.type_ids.get_unchecked(index);

                    let (nulls, index_mask) = &logical_nulls_array[type_id as u8 as usize];

                    // SAFETY:
                    // If index_mask is Max
                    // 1. On sparse union, every child len matches its parent, this is checked at union creation
                    // 2. That the null buffer len equals its array len is checked at array creation
                    // If index_mask is Zero, the null buffer len is 1
                    nulls.inner().value_unchecked(index & *index_mask as usize)
                })
            }
        }
    }
660
661    /// Returns a vector of tuples containing each field's type_id and its logical null buffer.
662    /// Only fields with non-zero null counts are included.
663    fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> {
664        self.fields
665            .iter()
666            .enumerate()
667            .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?)))
668            .filter(|(_, nulls)| nulls.null_count() > 0)
669            .collect()
670    }
671}
672
673impl From<ArrayData> for UnionArray {
674    fn from(data: ArrayData) -> Self {
675        let (fields, mode) = match data.data_type() {
676            DataType::Union(fields, mode) => (fields, *mode),
677            d => panic!("UnionArray expected ArrayData with type Union got {d}"),
678        };
679        let (type_ids, offsets) = match mode {
680            UnionMode::Sparse => (
681                ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
682                None,
683            ),
684            UnionMode::Dense => (
685                ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
686                Some(ScalarBuffer::new(
687                    data.buffers()[1].clone(),
688                    data.offset(),
689                    data.len(),
690                )),
691            ),
692        };
693
694        let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
695        let mut boxed_fields = vec![None; max_id + 1];
696        for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) {
697            boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
698        }
699        Self {
700            data_type: data.data_type().clone(),
701            type_ids,
702            offsets,
703            fields: boxed_fields,
704        }
705    }
706}
707
708impl From<UnionArray> for ArrayData {
709    fn from(array: UnionArray) -> Self {
710        let len = array.len();
711        let f = match &array.data_type {
712            DataType::Union(f, _) => f,
713            _ => unreachable!(),
714        };
715        let buffers = match array.offsets {
716            Some(o) => vec![array.type_ids.into_inner(), o.into_inner()],
717            None => vec![array.type_ids.into_inner()],
718        };
719
720        let child = f
721            .iter()
722            .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data())
723            .collect();
724
725        let builder = ArrayDataBuilder::new(array.data_type)
726            .len(len)
727            .buffers(buffers)
728            .child_data(child);
729        unsafe { builder.build_unchecked() }
730    }
731}
732
733impl Array for UnionArray {
    /// Returns `self` as a `&dyn Any`, which allows downcasting back to `UnionArray`.
    fn as_any(&self) -> &dyn Any {
        self
    }
737
    /// Returns the `ArrayData` for this array, produced by converting a clone of `self`.
    fn to_data(&self) -> ArrayData {
        self.clone().into()
    }
741
    /// Consumes the array and converts it into `ArrayData`.
    fn into_data(self) -> ArrayData {
        self.into()
    }
745
    /// Returns a reference to the stored data type of this array.
    fn data_type(&self) -> &DataType {
        &self.data_type
    }
749
    /// Returns a slice of this array as an `ArrayRef`.
    ///
    /// Wraps the result of `UnionArray`'s own `slice(offset, length)` in an `Arc`.
    fn slice(&self, offset: usize, length: usize) -> ArrayRef {
        Arc::new(self.slice(offset, length))
    }
753
    /// Returns the number of slots in the array, i.e. the length of the type ids buffer.
    fn len(&self) -> usize {
        self.type_ids.len()
    }
757
    /// Returns `true` if the type ids buffer is empty, i.e. the array has no slots.
    fn is_empty(&self) -> bool {
        self.type_ids.is_empty()
    }
761
762    fn shrink_to_fit(&mut self) {
763        self.type_ids.shrink_to_fit();
764        if let Some(offsets) = &mut self.offsets {
765            offsets.shrink_to_fit();
766        }
767        for array in self.fields.iter_mut().flatten() {
768            array.shrink_to_fit();
769        }
770        self.fields.shrink_to_fit();
771    }
772
773    fn offset(&self) -> usize {
774        0
775    }
776
777    fn nulls(&self) -> Option<&NullBuffer> {
778        None
779    }
780
781    fn logical_nulls(&self) -> Option<NullBuffer> {
782        let fields = match self.data_type() {
783            DataType::Union(fields, _) => fields,
784            _ => unreachable!(),
785        };
786
787        if fields.len() <= 1 {
788            return self.fields.iter().find_map(|field_opt| {
789                field_opt
790                    .as_ref()
791                    .and_then(|field| field.logical_nulls())
792                    .map(|logical_nulls| {
793                        if self.is_dense() {
794                            self.gather_nulls(vec![(0, logical_nulls)]).into()
795                        } else {
796                            logical_nulls
797                        }
798                    })
799            });
800        }
801
802        let logical_nulls = self.fields_logical_nulls();
803
804        if logical_nulls.is_empty() {
805            return None;
806        }
807
808        let fully_null_count = logical_nulls
809            .iter()
810            .filter(|(_, nulls)| nulls.null_count() == nulls.len())
811            .count();
812
813        if fully_null_count == fields.len() {
814            if let Some((_, exactly_sized)) = logical_nulls
815                .iter()
816                .find(|(_, nulls)| nulls.len() == self.len())
817            {
818                return Some(exactly_sized.clone());
819            }
820
821            if let Some((_, bigger)) = logical_nulls
822                .iter()
823                .find(|(_, nulls)| nulls.len() > self.len())
824            {
825                return Some(bigger.slice(0, self.len()));
826            }
827
828            return Some(NullBuffer::new_null(self.len()));
829        }
830
831        let boolean_buffer = match &self.offsets {
832            Some(_) => self.gather_nulls(logical_nulls),
833            None => {
834                // Choose the fastest way to compute the logical nulls
835                // Gather computes one null per iteration, while the others work on 64 nulls chunks,
836                // but must also compute selection masks, which is expensive,
837                // so it's cost is the number of selection masks computed per chunk
838                // Since computing the selection mask gets auto-vectorized, it's performance depends on which simd feature is enabled
839                // For gather, the cost is the threshold where masking becomes slower than gather, which is determined with benchmarks
840                // TODO: bench on avx512f(feature is still unstable)
841                let gather_relative_cost = if cfg!(target_feature = "avx2") {
842                    10
843                } else if cfg!(target_feature = "sse4.1") {
844                    3
845                } else if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
846                    // x86 baseline includes sse2
847                    2
848                } else {
849                    // TODO: bench on non x86
850                    // Always use gather on non benchmarked archs because even though it may slower on some cases,
851                    // it's performance depends only on the union length, without being affected by the number of fields
852                    0
853                };
854
855                let strategies = [
856                    (SparseStrategy::Gather, gather_relative_cost, true),
857                    (
858                        SparseStrategy::MaskAllFieldsWithNullsSkipOne,
859                        fields.len() - 1,
860                        fields.len() == logical_nulls.len(),
861                    ),
862                    (
863                        SparseStrategy::MaskSkipWithoutNulls,
864                        logical_nulls.len(),
865                        true,
866                    ),
867                    (
868                        SparseStrategy::MaskSkipFullyNull,
869                        fields.len() - fully_null_count,
870                        true,
871                    ),
872                ];
873
874                let (strategy, _, _) = strategies
875                    .iter()
876                    .filter(|(_, _, applicable)| *applicable)
877                    .min_by_key(|(_, cost, _)| cost)
878                    .unwrap();
879
880                match strategy {
881                    SparseStrategy::Gather => self.gather_nulls(logical_nulls),
882                    SparseStrategy::MaskAllFieldsWithNullsSkipOne => {
883                        self.mask_sparse_all_with_nulls_skip_one(logical_nulls)
884                    }
885                    SparseStrategy::MaskSkipWithoutNulls => {
886                        self.mask_sparse_skip_without_nulls(logical_nulls)
887                    }
888                    SparseStrategy::MaskSkipFullyNull => {
889                        self.mask_sparse_skip_fully_null(logical_nulls)
890                    }
891                }
892            }
893        };
894
895        let null_buffer = NullBuffer::from(boolean_buffer);
896
897        if null_buffer.null_count() > 0 {
898            Some(null_buffer)
899        } else {
900            None
901        }
902    }
903
904    fn is_nullable(&self) -> bool {
905        self.fields
906            .iter()
907            .flatten()
908            .any(|field| field.is_nullable())
909    }
910
911    fn get_buffer_memory_size(&self) -> usize {
912        let mut sum = self.type_ids.inner().capacity();
913        if let Some(o) = self.offsets.as_ref() {
914            sum += o.inner().capacity()
915        }
916        self.fields
917            .iter()
918            .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size()))
919            .sum::<usize>()
920            + sum
921    }
922
923    fn get_array_memory_size(&self) -> usize {
924        let mut sum = self.type_ids.inner().capacity();
925        if let Some(o) = self.offsets.as_ref() {
926            sum += o.inner().capacity()
927        }
928        std::mem::size_of::<Self>()
929            + self
930                .fields
931                .iter()
932                .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size()))
933                .sum::<usize>()
934            + sum
935    }
936}
937
938impl std::fmt::Debug for UnionArray {
939    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
940        let header = if self.is_dense() {
941            "UnionArray(Dense)\n["
942        } else {
943            "UnionArray(Sparse)\n["
944        };
945        writeln!(f, "{header}")?;
946
947        writeln!(f, "-- type id buffer:")?;
948        writeln!(f, "{:?}", self.type_ids)?;
949
950        if let Some(offsets) = &self.offsets {
951            writeln!(f, "-- offsets buffer:")?;
952            writeln!(f, "{offsets:?}")?;
953        }
954
955        let fields = match self.data_type() {
956            DataType::Union(fields, _) => fields,
957            _ => unreachable!(),
958        };
959
960        for (type_id, field) in fields.iter() {
961            let child = self.child(type_id);
962            writeln!(
963                f,
964                "-- child {}: \"{}\" ({:?})",
965                type_id,
966                field.name(),
967                field.data_type()
968            )?;
969            std::fmt::Debug::fmt(child, f)?;
970            writeln!(f)?;
971        }
972        writeln!(f, "]")
973    }
974}
975
/// Strategy used to compute the logical nulls of a sparse union.
///
/// Every strategy yields the same result. The `Mask*` strategies apply bitwise
/// masking to chunks of 64 values, which requires computing expensive selection
/// masks for some fields; they differ only in which fields need a selection mask.
enum SparseStrategy {
    /// Gather individual bits from the null buffer of the selected field.
    Gather,
    /// All fields contain nulls, so the selection mask of one field can be
    /// skipped by negating the masks of the others.
    MaskAllFieldsWithNullsSkipOne,
    /// Skip the selection mask computation of the fields without nulls.
    MaskSkipWithoutNulls,
    /// Skip the selection mask computation of the fully null fields.
    MaskSkipFullyNull,
}
990
/// A `usize`-sized mask whose bits are either all clear or all set.
#[derive(Clone, Copy)]
#[repr(usize)]
enum Mask {
    /// No bit set.
    Zero = 0,
    /// Every bit set.
    // The lint below is a false positive, see
    // https://github.com/rust-lang/rust-clippy/issues/8043
    #[allow(clippy::enum_clike_unportable_variant)]
    Max = usize::MAX,
}
999
/// Packs a chunk of type ids into a bitmask: bit `i` of the result is set when
/// `type_ids_chunk[i]` equals `type_id`.
///
/// The result is a `u64`, so the chunk is expected to hold at most 64 ids.
fn selection_mask(type_ids_chunk: &[i8], type_id: i8) -> u64 {
    let mut packed = 0_u64;
    for (bit_idx, &candidate) in type_ids_chunk.iter().enumerate() {
        packed |= u64::from(candidate == type_id) << bit_idx;
    }
    packed
}
1009
1010/// Returns a bitmask where bits indicate if any id from `without_nulls_ids` exist in `type_ids_chunk`.
1011fn without_nulls_selected(type_ids_chunk: &[i8], without_nulls_ids: &[i8]) -> u64 {
1012    without_nulls_ids
1013        .iter()
1014        .fold(0, |fully_valid_selected, field_type_id| {
1015            fully_valid_selected | selection_mask(type_ids_chunk, *field_type_id)
1016        })
1017}
1018
1019#[cfg(test)]
1020mod tests {
1021    use super::*;
1022    use std::collections::HashSet;
1023
1024    use crate::array::Int8Type;
1025    use crate::builder::UnionBuilder;
1026    use crate::cast::AsArray;
1027    use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
1028    use crate::{Float64Array, Int32Array, Int64Array, StringArray};
1029    use crate::{Int8Array, RecordBatch};
1030    use arrow_buffer::Buffer;
1031    use arrow_schema::{Field, Schema};
1032
/// Dense union with three Int32 fields: checks type ids, offsets, children and values.
#[test]
fn test_dense_i32() {
    let mut builder = UnionBuilder::new_dense();
    for (name, value) in [
        ("a", 1),
        ("b", 2),
        ("c", 3),
        ("a", 4),
        ("c", 5),
        ("a", 6),
        ("b", 7),
    ] {
        builder.append::<Int32Type>(name, value).unwrap();
    }
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
    let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
    let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];

    // Type ids, as a whole buffer and slot by slot
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (i, &type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, union.type_id(i));
    }

    // Offsets, as a whole buffer and slot by slot
    assert_eq!(*union.offsets().unwrap(), expected_offsets);
    for (i, &offset) in expected_offsets.iter().enumerate() {
        assert_eq!(union.value_offset(i), offset as usize);
    }

    // Each child only holds the values appended to it
    assert_eq!(
        *union.child(0).as_primitive::<Int32Type>().values(),
        [1_i32, 4, 6]
    );
    assert_eq!(
        *union.child(1).as_primitive::<Int32Type>().values(),
        [2_i32, 7]
    );
    assert_eq!(
        *union.child(2).as_primitive::<Int32Type>().values(),
        [3_i32, 5]
    );

    assert_eq!(expected_array_values.len(), union.len());
    for (i, &expected) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(i));
        let slot = union.value(i);
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    }
}
1085
/// Slicing a single-field dense union must keep the logical nulls aligned.
#[test]
fn slice_union_array_single_field() {
    // Dense union holding [1, null, 3, null, 4] in its only field "a"
    let mut builder = UnionBuilder::new_dense();
    for value in [Some(1), None, Some(3), None, Some(4)] {
        match value {
            Some(v) => builder.append::<Int32Type>("a", v).unwrap(),
            None => builder.append_null::<Int32Type>("a").unwrap(),
        }
    }
    let union_array = builder.build().unwrap();

    // The slice covers [null, 3, null]
    let union_slice = union_array.slice(1, 3);
    let logical_nulls = union_slice.logical_nulls().unwrap();

    assert_eq!(logical_nulls.len(), 3);
    assert_eq!(
        [
            logical_nulls.is_null(0),
            logical_nulls.is_valid(1),
            logical_nulls.is_null(2),
        ],
        [true; 3]
    );
}
1109
/// Dense union with 1024 values in a single Int32 field.
#[test]
#[cfg_attr(miri, ignore)]
fn test_dense_i32_large() {
    let expected_type_ids = vec![0_i8; 1024];
    let expected_offsets: Vec<i32> = (0..1024).collect();
    let expected_array_values: Vec<i32> = (1..=1024).collect();

    let mut builder = UnionBuilder::new_dense();
    for &value in &expected_array_values {
        builder.append::<Int32Type>("a", value).unwrap();
    }
    let union = builder.build().unwrap();

    // Type ids, as a whole buffer and slot by slot
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (i, &type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, union.type_id(i));
    }

    // Offsets, as a whole buffer and slot by slot
    assert_eq!(*union.offsets().unwrap(), expected_offsets);
    for (i, &offset) in expected_offsets.iter().enumerate() {
        assert_eq!(union.value_offset(i), offset as usize);
    }

    for (i, &expected) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(i));
        let slot = union.value(i);
        let ints = slot.as_primitive::<Int32Type>();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    }
}
1146
/// Dense union mixing Int32 and Int64 fields, without nulls.
#[test]
fn test_dense_mixed() {
    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int64Type>("c", 3).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    builder.append::<Int64Type>("c", 5).unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    // A slot must be a one-element array of the expected type and value
    let expect_i32 = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };
    let expect_i64 = |slot: &ArrayRef, expected: i64| {
        let ints = slot.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };

    assert_eq!(5, union.len());
    for i in 0..union.len() {
        let slot = union.value(i);
        assert!(!union.is_null(i));
        match i {
            0 => expect_i32(&slot, 1),
            1 => expect_i64(&slot, 3),
            2 => expect_i32(&slot, 4),
            3 => expect_i64(&slot, 5),
            4 => expect_i32(&slot, 6),
            _ => unreachable!(),
        }
    }
}
1196
/// Dense union mixing Int32 and Int64 fields, with one null Int32 slot.
#[test]
fn test_dense_mixed_with_nulls() {
    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int64Type>("c", 3).unwrap();
    builder.append::<Int32Type>("a", 10).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    // A slot must be a valid one-element array of the expected type and value
    let expect_i32 = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert!(!ints.is_null(0));
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };
    let expect_i64 = |slot: &ArrayRef, expected: i64| {
        let ints = slot.as_any().downcast_ref::<Int64Array>().unwrap();
        assert!(!ints.is_null(0));
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };

    assert_eq!(5, union.len());
    for i in 0..union.len() {
        let slot = union.value(i);
        match i {
            0 => expect_i32(&slot, 1),
            1 => expect_i64(&slot, 3),
            2 => expect_i32(&slot, 10),
            3 => assert!(slot.is_null(0)),
            4 => expect_i32(&slot, 6),
            _ => unreachable!(),
        }
    }
}
1244
/// Slice of a dense union with nulls: values must follow the slice offset.
#[test]
fn test_dense_mixed_with_nulls_and_offset() {
    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int64Type>("c", 3).unwrap();
    builder.append::<Int32Type>("a", 10).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    // Keep the last three slots: [10, null, 6]
    let slice = union.slice(2, 3);
    let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();

    // A slot must be a valid one-element Int32 array holding the expected value
    let expect_i32 = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert!(!ints.is_null(0));
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };

    assert_eq!(3, new_union.len());
    for i in 0..new_union.len() {
        let slot = new_union.value(i);
        match i {
            0 => expect_i32(&slot, 10),
            1 => assert!(slot.is_null(0)),
            2 => expect_i32(&slot, 6),
            _ => unreachable!(),
        }
    }
}
1281
/// Dense union built with `try_new` from string, int and float children.
#[test]
fn test_dense_mixed_with_str() {
    let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
    let int_array = Int32Array::from(vec![5, 6]);
    let float_array = Float64Array::from(vec![10.0]);

    let type_ids: ScalarBuffer<i8> = [1, 0, 0, 2, 0, 1].into_iter().collect();
    let offsets: ScalarBuffer<i32> = [0, 0, 1, 0, 2, 1].into_iter().collect();

    let fields: UnionFields = [
        (0, Arc::new(Field::new("A", DataType::Utf8, false))),
        (1, Arc::new(Field::new("B", DataType::Int32, false))),
        (2, Arc::new(Field::new("C", DataType::Float64, false))),
    ]
    .into_iter()
    .collect();
    let children = vec![
        Arc::new(string_array) as Arc<dyn Array>,
        Arc::new(int_array),
        Arc::new(float_array),
    ];
    let array =
        UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap();

    // Type ids, as a whole buffer and slot by slot
    assert_eq!(*array.type_ids(), type_ids);
    for (i, &type_id) in type_ids.iter().enumerate() {
        assert_eq!(type_id, array.type_id(i));
    }

    // Offsets, as a whole buffer and slot by slot
    assert_eq!(*array.offsets().unwrap(), offsets);
    for (i, &offset) in offsets.iter().enumerate() {
        assert_eq!(offset as usize, array.value_offset(i));
    }

    // Values
    assert_eq!(6, array.len());

    let expect_str = |index: usize, expected: &str| {
        let slot = array.value(index);
        let strings = slot.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(expected, strings.value(0));
    };
    let expect_i32 = |index: usize, expected: i32| {
        let slot = array.value(index);
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(expected, ints.value(0));
    };
    let expect_f64 = |index: usize, expected: f64| {
        let slot = array.value(index);
        let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(expected, floats.value(0));
    };

    expect_i32(0, 5);
    expect_str(1, "foo");
    expect_str(2, "bar");
    expect_f64(3, 10.0);
    expect_str(4, "baz");
    expect_i32(5, 6);
}
1365
/// Sparse union with three Int32 fields: checks type ids, children and values.
#[test]
fn test_sparse_i32() {
    let mut builder = UnionBuilder::new_sparse();
    for (name, value) in [
        ("a", 1),
        ("b", 2),
        ("c", 3),
        ("a", 4),
        ("c", 5),
        ("a", 6),
        ("b", 7),
    ] {
        builder.append::<Int32Type>(name, value).unwrap();
    }
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
    let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];

    // Type ids, as a whole buffer and slot by slot
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (i, &type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, union.type_id(i));
    }

    // A sparse union should only have a single buffer, i.e. no offsets
    assert!(union.offsets().is_none());

    // Every child spans the whole union; slots owned by other fields hold 0
    assert_eq!(
        *union.child(0).as_primitive::<Int32Type>().values(),
        [1_i32, 0, 0, 4, 0, 6, 0],
    );
    assert_eq!(
        *union.child(1).as_primitive::<Int32Type>().values(),
        [0_i32, 2_i32, 0, 0, 0, 0, 7]
    );
    assert_eq!(
        *union.child(2).as_primitive::<Int32Type>().values(),
        [0_i32, 0, 3_i32, 0, 5, 0, 0]
    );

    assert_eq!(expected_array_values.len(), union.len());
    for (i, &expected) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(i));
        let slot = union.value(i);
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    }
}
1414
/// Sparse union mixing Int32 and Float64 fields, without nulls.
#[test]
fn test_sparse_mixed() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    builder.append::<Float64Type>("c", 5.0).unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 1, 0, 1, 0];

    // Type ids, as a whole buffer and slot by slot
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (i, &type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, union.type_id(i));
    }

    // A sparse union should only have a single buffer, i.e. no offsets
    assert!(union.offsets().is_none());

    // A slot must be a one-element array of the expected type and value
    let expect_i32 = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };
    let expect_f64 = |slot: &ArrayRef, expected: f64| {
        let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(floats.len(), 1);
        assert_eq!(expected, floats.value(0));
    };

    for i in 0..union.len() {
        let slot = union.value(i);
        assert!(!union.is_null(i));
        match i {
            0 => expect_i32(&slot, 1),
            1 => expect_f64(&slot, 3_f64),
            2 => expect_i32(&slot, 4),
            3 => expect_f64(&slot, 5_f64),
            4 => expect_i32(&slot, 6),
            _ => unreachable!(),
        }
    }
}
1474
/// Sparse union mixing Int32 and Float64 fields, with one null Int32 slot.
#[test]
fn test_sparse_mixed_with_nulls() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 0, 1, 0];

    // Type ids, as a whole buffer and slot by slot
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (i, &type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, union.type_id(i));
    }

    // A sparse union should only have a single buffer, i.e. no offsets
    assert!(union.offsets().is_none());

    // A slot must be a valid one-element array of the expected type and value
    let expect_i32 = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert!(!ints.is_null(0));
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };
    let expect_f64 = |slot: &ArrayRef, expected: f64| {
        let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
        assert!(!floats.is_null(0));
        assert_eq!(floats.len(), 1);
        assert_eq!(expected, floats.value(0));
    };

    for i in 0..union.len() {
        let slot = union.value(i);
        match i {
            0 => expect_i32(&slot, 1),
            1 => assert!(slot.is_null(0)),
            2 => expect_f64(&slot, 3_f64),
            3 => expect_i32(&slot, 4),
            _ => unreachable!(),
        }
    }
}
1524
/// Slice of a sparse union with nulls: values must follow the slice offset.
#[test]
fn test_sparse_mixed_with_nulls_and_offset() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append_null::<Float64Type>("c").unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let union = builder.build().unwrap();

    // Keep the last four slots: [null, 3.0, null, 4]
    let slice = union.slice(1, 4);
    let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();

    // A slot must be a valid one-element array of the expected type and value
    let expect_f64 = |slot: &ArrayRef, expected: f64| {
        let floats = slot.as_primitive::<Float64Type>();
        assert!(!floats.is_null(0));
        assert_eq!(floats.len(), 1);
        assert_eq!(expected, floats.value(0));
    };
    let expect_i32 = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_primitive::<Int32Type>();
        assert!(!ints.is_null(0));
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };

    assert_eq!(4, new_union.len());
    for i in 0..new_union.len() {
        let slot = new_union.value(i);
        match i {
            0 => assert!(slot.is_null(0)),
            1 => expect_f64(&slot, 3_f64),
            2 => assert!(slot.is_null(0)),
            3 => expect_i32(&slot, 4),
            _ => unreachable!(),
        }
    }
}
1562
1563    fn test_union_validity(union_array: &UnionArray) {
1564        assert_eq!(union_array.null_count(), 0);
1565
1566        for i in 0..union_array.len() {
1567            assert!(!union_array.is_null(i));
1568            assert!(union_array.is_valid(i));
1569        }
1570    }
1571
1572    #[test]
1573    fn test_union_array_validity() {
1574        let mut builder = UnionBuilder::new_sparse();
1575        builder.append::<Int32Type>("a", 1).unwrap();
1576        builder.append_null::<Int32Type>("a").unwrap();
1577        builder.append::<Float64Type>("c", 3.0).unwrap();
1578        builder.append_null::<Float64Type>("c").unwrap();
1579        builder.append::<Int32Type>("a", 4).unwrap();
1580        let union = builder.build().unwrap();
1581
1582        test_union_validity(&union);
1583
1584        let mut builder = UnionBuilder::new_dense();
1585        builder.append::<Int32Type>("a", 1).unwrap();
1586        builder.append_null::<Int32Type>("a").unwrap();
1587        builder.append::<Float64Type>("c", 3.0).unwrap();
1588        builder.append_null::<Float64Type>("c").unwrap();
1589        builder.append::<Int32Type>("a", 4).unwrap();
1590        let union = builder.build().unwrap();
1591
1592        test_union_validity(&union);
1593    }
1594
1595    #[test]
1596    fn test_type_check() {
1597        let mut builder = UnionBuilder::new_sparse();
1598        builder.append::<Float32Type>("a", 1.0).unwrap();
1599        let err = builder.append::<Int32Type>("a", 1).unwrap_err().to_string();
1600        assert!(
1601            err.contains(
1602                "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"
1603            ),
1604            "{}",
1605            err
1606        );
1607    }
1608
1609    #[test]
1610    fn slice_union_array() {
1611        // [1, null, 3.0, null, 4]
1612        fn create_union(mut builder: UnionBuilder) -> UnionArray {
1613            builder.append::<Int32Type>("a", 1).unwrap();
1614            builder.append_null::<Int32Type>("a").unwrap();
1615            builder.append::<Float64Type>("c", 3.0).unwrap();
1616            builder.append_null::<Float64Type>("c").unwrap();
1617            builder.append::<Int32Type>("a", 4).unwrap();
1618            builder.build().unwrap()
1619        }
1620
1621        fn create_batch(union: UnionArray) -> RecordBatch {
1622            let schema = Schema::new(vec![Field::new(
1623                "struct_array",
1624                union.data_type().clone(),
1625                true,
1626            )]);
1627
1628            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(union)]).unwrap()
1629        }
1630
1631        fn test_slice_union(record_batch_slice: RecordBatch) {
1632            let union_slice = record_batch_slice
1633                .column(0)
1634                .as_any()
1635                .downcast_ref::<UnionArray>()
1636                .unwrap();
1637
1638            assert_eq!(union_slice.type_id(0), 0);
1639            assert_eq!(union_slice.type_id(1), 1);
1640            assert_eq!(union_slice.type_id(2), 1);
1641
1642            let slot = union_slice.value(0);
1643            let array = slot.as_primitive::<Int32Type>();
1644            assert_eq!(array.len(), 1);
1645            assert!(array.is_null(0));
1646
1647            let slot = union_slice.value(1);
1648            let array = slot.as_primitive::<Float64Type>();
1649            assert_eq!(array.len(), 1);
1650            assert!(array.is_valid(0));
1651            assert_eq!(array.value(0), 3.0);
1652
1653            let slot = union_slice.value(2);
1654            let array = slot.as_primitive::<Float64Type>();
1655            assert_eq!(array.len(), 1);
1656            assert!(array.is_null(0));
1657        }
1658
1659        // Sparse Union
1660        let builder = UnionBuilder::new_sparse();
1661        let record_batch = create_batch(create_union(builder));
1662        // [null, 3.0, null]
1663        let record_batch_slice = record_batch.slice(1, 3);
1664        test_slice_union(record_batch_slice);
1665
1666        // Dense Union
1667        let builder = UnionBuilder::new_dense();
1668        let record_batch = create_batch(create_union(builder));
1669        // [null, 3.0, null]
1670        let record_batch_slice = record_batch.slice(1, 3);
1671        test_slice_union(record_batch_slice);
1672    }
1673
1674    #[test]
1675    fn test_custom_type_ids() {
1676        let data_type = DataType::Union(
1677            UnionFields::new(
1678                vec![8, 4, 9],
1679                vec![
1680                    Field::new("strings", DataType::Utf8, false),
1681                    Field::new("integers", DataType::Int32, false),
1682                    Field::new("floats", DataType::Float64, false),
1683                ],
1684            ),
1685            UnionMode::Dense,
1686        );
1687
1688        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1689        let int_array = Int32Array::from(vec![5, 6, 4]);
1690        let float_array = Float64Array::from(vec![10.0]);
1691
1692        let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
1693        let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
1694
1695        let data = ArrayData::builder(data_type)
1696            .len(7)
1697            .buffers(vec![type_ids, value_offsets])
1698            .child_data(vec![
1699                string_array.into_data(),
1700                int_array.into_data(),
1701                float_array.into_data(),
1702            ])
1703            .build()
1704            .unwrap();
1705
1706        let array = UnionArray::from(data);
1707
1708        let v = array.value(0);
1709        assert_eq!(v.data_type(), &DataType::Int32);
1710        assert_eq!(v.len(), 1);
1711        assert_eq!(v.as_primitive::<Int32Type>().value(0), 5);
1712
1713        let v = array.value(1);
1714        assert_eq!(v.data_type(), &DataType::Utf8);
1715        assert_eq!(v.len(), 1);
1716        assert_eq!(v.as_string::<i32>().value(0), "foo");
1717
1718        let v = array.value(2);
1719        assert_eq!(v.data_type(), &DataType::Int32);
1720        assert_eq!(v.len(), 1);
1721        assert_eq!(v.as_primitive::<Int32Type>().value(0), 6);
1722
1723        let v = array.value(3);
1724        assert_eq!(v.data_type(), &DataType::Utf8);
1725        assert_eq!(v.len(), 1);
1726        assert_eq!(v.as_string::<i32>().value(0), "bar");
1727
1728        let v = array.value(4);
1729        assert_eq!(v.data_type(), &DataType::Float64);
1730        assert_eq!(v.len(), 1);
1731        assert_eq!(v.as_primitive::<Float64Type>().value(0), 10.0);
1732
1733        let v = array.value(5);
1734        assert_eq!(v.data_type(), &DataType::Int32);
1735        assert_eq!(v.len(), 1);
1736        assert_eq!(v.as_primitive::<Int32Type>().value(0), 4);
1737
1738        let v = array.value(6);
1739        assert_eq!(v.data_type(), &DataType::Utf8);
1740        assert_eq!(v.len(), 1);
1741        assert_eq!(v.as_string::<i32>().value(0), "baz");
1742    }
1743
1744    #[test]
1745    fn into_parts() {
1746        let mut builder = UnionBuilder::new_dense();
1747        builder.append::<Int32Type>("a", 1).unwrap();
1748        builder.append::<Int8Type>("b", 2).unwrap();
1749        builder.append::<Int32Type>("a", 3).unwrap();
1750        let dense_union = builder.build().unwrap();
1751
1752        let field = [
1753            &Arc::new(Field::new("a", DataType::Int32, false)),
1754            &Arc::new(Field::new("b", DataType::Int8, false)),
1755        ];
1756        let (union_fields, type_ids, offsets, children) = dense_union.into_parts();
1757        assert_eq!(
1758            union_fields
1759                .iter()
1760                .map(|(_, field)| field)
1761                .collect::<Vec<_>>(),
1762            field
1763        );
1764        assert_eq!(type_ids, [0, 1, 0]);
1765        assert!(offsets.is_some());
1766        assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);
1767
1768        let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
1769        assert!(result.is_ok());
1770        assert_eq!(result.unwrap().len(), 3);
1771
1772        let mut builder = UnionBuilder::new_sparse();
1773        builder.append::<Int32Type>("a", 1).unwrap();
1774        builder.append::<Int8Type>("b", 2).unwrap();
1775        builder.append::<Int32Type>("a", 3).unwrap();
1776        let sparse_union = builder.build().unwrap();
1777
1778        let (union_fields, type_ids, offsets, children) = sparse_union.into_parts();
1779        assert_eq!(type_ids, [0, 1, 0]);
1780        assert!(offsets.is_none());
1781
1782        let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
1783        assert!(result.is_ok());
1784        assert_eq!(result.unwrap().len(), 3);
1785    }
1786
1787    #[test]
1788    fn into_parts_custom_type_ids() {
1789        let set_field_type_ids: [i8; 3] = [8, 4, 9];
1790        let data_type = DataType::Union(
1791            UnionFields::new(
1792                set_field_type_ids,
1793                [
1794                    Field::new("strings", DataType::Utf8, false),
1795                    Field::new("integers", DataType::Int32, false),
1796                    Field::new("floats", DataType::Float64, false),
1797                ],
1798            ),
1799            UnionMode::Dense,
1800        );
1801        let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1802        let int_array = Int32Array::from(vec![5, 6, 4]);
1803        let float_array = Float64Array::from(vec![10.0]);
1804        let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
1805        let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
1806        let data = ArrayData::builder(data_type)
1807            .len(7)
1808            .buffers(vec![type_ids, value_offsets])
1809            .child_data(vec![
1810                string_array.into_data(),
1811                int_array.into_data(),
1812                float_array.into_data(),
1813            ])
1814            .build()
1815            .unwrap();
1816        let array = UnionArray::from(data);
1817
1818        let (union_fields, type_ids, offsets, children) = array.into_parts();
1819        assert_eq!(
1820            type_ids.iter().collect::<HashSet<_>>(),
1821            set_field_type_ids.iter().collect::<HashSet<_>>()
1822        );
1823        let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
1824        assert!(result.is_ok());
1825        let array = result.unwrap();
1826        assert_eq!(array.len(), 7);
1827    }
1828
1829    #[test]
1830    fn test_invalid() {
1831        let fields = UnionFields::new(
1832            [3, 2],
1833            [
1834                Field::new("a", DataType::Utf8, false),
1835                Field::new("b", DataType::Utf8, false),
1836            ],
1837        );
1838        let children = vec![
1839            Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
1840            Arc::new(StringArray::from_iter_values(["c", "d"])) as _,
1841        ];
1842
1843        let type_ids = vec![3, 3, 2].into();
1844        let err =
1845            UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap_err();
1846        assert_eq!(
1847            err.to_string(),
1848            "Invalid argument error: Sparse union child arrays must be equal in length to the length of the union"
1849        );
1850
1851        let type_ids = vec![1, 2].into();
1852        let err =
1853            UnionArray::try_new(fields.clone(), type_ids, None, children.clone()).unwrap_err();
1854        assert_eq!(
1855            err.to_string(),
1856            "Invalid argument error: Type Ids values must match one of the field type ids"
1857        );
1858
1859        let type_ids = vec![7, 2].into();
1860        let err = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap_err();
1861        assert_eq!(
1862            err.to_string(),
1863            "Invalid argument error: Type Ids values must match one of the field type ids"
1864        );
1865
1866        let children = vec![
1867            Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
1868            Arc::new(StringArray::from_iter_values(["c"])) as _,
1869        ];
1870        let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]);
1871        let offsets = Some(vec![0, 1, 0].into());
1872        UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children.clone()).unwrap();
1873
1874        let offsets = Some(vec![0, 1, 1].into());
1875        let err = UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children.clone())
1876            .unwrap_err();
1877
1878        assert_eq!(
1879            err.to_string(),
1880            "Invalid argument error: Offsets must be positive and within the length of the Array"
1881        );
1882
1883        let offsets = Some(vec![0, 1].into());
1884        let err =
1885            UnionArray::try_new(fields.clone(), type_ids.clone(), offsets, children).unwrap_err();
1886
1887        assert_eq!(
1888            err.to_string(),
1889            "Invalid argument error: Type Ids and Offsets lengths must match"
1890        );
1891
1892        let err = UnionArray::try_new(fields.clone(), type_ids, None, vec![]).unwrap_err();
1893
1894        assert_eq!(
1895            err.to_string(),
1896            "Invalid argument error: Union fields length must match child arrays length"
1897        );
1898    }
1899
1900    #[test]
1901    fn test_logical_nulls_fast_paths() {
1902        // fields.len() <= 1
1903        let array = UnionArray::try_new(UnionFields::empty(), vec![].into(), None, vec![]).unwrap();
1904
1905        assert_eq!(array.logical_nulls(), None);
1906
1907        let fields = UnionFields::new(
1908            [1, 3],
1909            [
1910                Field::new("a", DataType::Int8, false), // non nullable
1911                Field::new("b", DataType::Int8, false), // non nullable
1912            ],
1913        );
1914        let array = UnionArray::try_new(
1915            fields,
1916            vec![1].into(),
1917            None,
1918            vec![
1919                Arc::new(Int8Array::from_value(5, 1)),
1920                Arc::new(Int8Array::from_value(5, 1)),
1921            ],
1922        )
1923        .unwrap();
1924
1925        assert_eq!(array.logical_nulls(), None);
1926
1927        let nullable_fields = UnionFields::new(
1928            [1, 3],
1929            [
1930                Field::new("a", DataType::Int8, true), // nullable but without nulls
1931                Field::new("b", DataType::Int8, true), // nullable but without nulls
1932            ],
1933        );
1934        let array = UnionArray::try_new(
1935            nullable_fields.clone(),
1936            vec![1, 1].into(),
1937            None,
1938            vec![
1939                Arc::new(Int8Array::from_value(-5, 2)), // nullable but without nulls
1940                Arc::new(Int8Array::from_value(-5, 2)), // nullable but without nulls
1941            ],
1942        )
1943        .unwrap();
1944
1945        assert_eq!(array.logical_nulls(), None);
1946
1947        let array = UnionArray::try_new(
1948            nullable_fields.clone(),
1949            vec![1, 1].into(),
1950            None,
1951            vec![
1952                // every children is completly null
1953                Arc::new(Int8Array::new_null(2)), // all null, same len as it's parent
1954                Arc::new(Int8Array::new_null(2)), // all null, same len as it's parent
1955            ],
1956        )
1957        .unwrap();
1958
1959        assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));
1960
1961        let array = UnionArray::try_new(
1962            nullable_fields.clone(),
1963            vec![1, 1].into(),
1964            Some(vec![0, 1].into()),
1965            vec![
1966                // every children is completly null
1967                Arc::new(Int8Array::new_null(3)), // bigger that parent
1968                Arc::new(Int8Array::new_null(3)), // bigger that parent
1969            ],
1970        )
1971        .unwrap();
1972
1973        assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));
1974    }
1975
1976    #[test]
1977    fn test_dense_union_logical_nulls_gather() {
1978        // union of [{A=1}, {A=2}, {B=3.2}, {B=}, {C=}, {C=}]
1979        let int_array = Int32Array::from(vec![1, 2]);
1980        let float_array = Float64Array::from(vec![Some(3.2), None]);
1981        let str_array = StringArray::new_null(1);
1982        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();
1983        let offsets = [0, 1, 0, 1, 0, 0]
1984            .into_iter()
1985            .collect::<ScalarBuffer<i32>>();
1986
1987        let children = vec![
1988            Arc::new(int_array) as Arc<dyn Array>,
1989            Arc::new(float_array),
1990            Arc::new(str_array),
1991        ];
1992
1993        let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();
1994
1995        let expected = BooleanBuffer::from(vec![true, true, true, false, false, false]);
1996
1997        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
1998        assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
1999    }
2000
2001    #[test]
2002    fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() {
2003        let fields: UnionFields = [
2004            (1, Arc::new(Field::new("A", DataType::Int32, true))),
2005            (3, Arc::new(Field::new("B", DataType::Float64, true))),
2006        ]
2007        .into_iter()
2008        .collect();
2009
2010        // union of [{A=}, {A=}, {B=3.2}, {B=}]
2011        let int_array = Int32Array::new_null(4);
2012        let float_array = Float64Array::from(vec![None, None, Some(3.2), None]);
2013        let type_ids = [1, 1, 3, 3].into_iter().collect::<ScalarBuffer<i8>>();
2014
2015        let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2016
2017        let array = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap();
2018
2019        let expected = BooleanBuffer::from(vec![false, false, true, false]);
2020
2021        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2022        assert_eq!(
2023            expected,
2024            array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
2025        );
2026
2027        //like above, but repeated to genereate two exact bitmasks and a non empty remainder
2028        let len = 2 * 64 + 32;
2029
2030        let int_array = Int32Array::new_null(len);
2031        let float_array = Float64Array::from_iter([Some(3.2), None].into_iter().cycle().take(len));
2032        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3].into_iter().cycle().take(len));
2033
2034        let array = UnionArray::try_new(
2035            fields,
2036            type_ids,
2037            None,
2038            vec![Arc::new(int_array), Arc::new(float_array)],
2039        )
2040        .unwrap();
2041
2042        let expected =
2043            BooleanBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len));
2044
2045        assert_eq!(array.len(), len);
2046        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2047        assert_eq!(
2048            expected,
2049            array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
2050        );
2051    }
2052
2053    #[test]
2054    fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_valid() {
2055        // union of [{A=2}, {A=2}, {B=3.2}, {B=}, {C=}, {C=}]
2056        let int_array = Int32Array::from_value(2, 6);
2057        let float_array = Float64Array::from_value(4.2, 6);
2058        let str_array = StringArray::new_null(6);
2059        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();
2060
2061        let children = vec![
2062            Arc::new(int_array) as Arc<dyn Array>,
2063            Arc::new(float_array),
2064            Arc::new(str_array),
2065        ];
2066
2067        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2068
2069        let expected = BooleanBuffer::from(vec![true, true, true, true, false, false]);
2070
2071        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2072        assert_eq!(
2073            expected,
2074            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
2075        );
2076
2077        //like above, but repeated to genereate two exact bitmasks and a non empty remainder
2078        let len = 2 * 64 + 32;
2079
2080        let int_array = Int32Array::from_value(2, len);
2081        let float_array = Float64Array::from_value(4.2, len);
2082        let str_array = StringArray::from_iter([None, Some("a")].into_iter().cycle().take(len));
2083        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
2084
2085        let children = vec![
2086            Arc::new(int_array) as Arc<dyn Array>,
2087            Arc::new(float_array),
2088            Arc::new(str_array),
2089        ];
2090
2091        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2092
2093        let expected = BooleanBuffer::from_iter(
2094            [true, true, true, true, false, true]
2095                .into_iter()
2096                .cycle()
2097                .take(len),
2098        );
2099
2100        assert_eq!(array.len(), len);
2101        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2102        assert_eq!(
2103            expected,
2104            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
2105        );
2106    }
2107
2108    #[test]
2109    fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_null() {
2110        // union of [{A=}, {A=}, {B=4.2}, {B=4.2}, {C=}, {C=}]
2111        let int_array = Int32Array::new_null(6);
2112        let float_array = Float64Array::from_value(4.2, 6);
2113        let str_array = StringArray::new_null(6);
2114        let type_ids = [1, 1, 3, 3, 4, 4].into_iter().collect::<ScalarBuffer<i8>>();
2115
2116        let children = vec![
2117            Arc::new(int_array) as Arc<dyn Array>,
2118            Arc::new(float_array),
2119            Arc::new(str_array),
2120        ];
2121
2122        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2123
2124        let expected = BooleanBuffer::from(vec![false, false, true, true, false, false]);
2125
2126        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2127        assert_eq!(
2128            expected,
2129            array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
2130        );
2131
2132        //like above, but repeated to genereate two exact bitmasks and a non empty remainder
2133        let len = 2 * 64 + 32;
2134
2135        let int_array = Int32Array::new_null(len);
2136        let float_array = Float64Array::from_value(4.2, len);
2137        let str_array = StringArray::new_null(len);
2138        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
2139
2140        let children = vec![
2141            Arc::new(int_array) as Arc<dyn Array>,
2142            Arc::new(float_array),
2143            Arc::new(str_array),
2144        ];
2145
2146        let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
2147
2148        let expected = BooleanBuffer::from_iter(
2149            [false, false, true, true, false, false]
2150                .into_iter()
2151                .cycle()
2152                .take(len),
2153        );
2154
2155        assert_eq!(array.len(), len);
2156        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2157        assert_eq!(
2158            expected,
2159            array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
2160        );
2161    }
2162
2163    #[test]
2164    fn test_sparse_union_logical_nulls_gather() {
2165        let n_fields = 50;
2166
2167        let non_null = Int32Array::from_value(2, 4);
2168        let mixed = Int32Array::from(vec![None, None, Some(1), None]);
2169        let fully_null = Int32Array::new_null(4);
2170
2171        let array = UnionArray::try_new(
2172            (1..)
2173                .step_by(2)
2174                .map(|i| {
2175                    (
2176                        i,
2177                        Arc::new(Field::new(format!("f{i}"), DataType::Int32, true)),
2178                    )
2179                })
2180                .take(n_fields)
2181                .collect(),
2182            vec![1, 3, 3, 5].into(),
2183            None,
2184            [
2185                Arc::new(non_null) as ArrayRef,
2186                Arc::new(mixed),
2187                Arc::new(fully_null),
2188            ]
2189            .into_iter()
2190            .cycle()
2191            .take(n_fields)
2192            .collect(),
2193        )
2194        .unwrap();
2195
2196        let expected = BooleanBuffer::from(vec![true, false, true, false]);
2197
2198        assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
2199        assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
2200    }
2201
2202    fn union_fields() -> UnionFields {
2203        [
2204            (1, Arc::new(Field::new("A", DataType::Int32, true))),
2205            (3, Arc::new(Field::new("B", DataType::Float64, true))),
2206            (4, Arc::new(Field::new("C", DataType::Utf8, true))),
2207        ]
2208        .into_iter()
2209        .collect()
2210    }
2211
2212    #[test]
2213    fn test_is_nullable() {
2214        assert!(!create_union_array(false, false).is_nullable());
2215        assert!(create_union_array(true, false).is_nullable());
2216        assert!(create_union_array(false, true).is_nullable());
2217        assert!(create_union_array(true, true).is_nullable());
2218    }
2219
2220    /// Create a union array with a float and integer field
2221    ///
2222    /// If the `int_nullable` is true, the integer field will have nulls
2223    /// If the `float_nullable` is true, the float field will have nulls
2224    ///
2225    /// Note the `Field` definitions are always declared to be nullable
2226    fn create_union_array(int_nullable: bool, float_nullable: bool) -> UnionArray {
2227        let int_array = if int_nullable {
2228            Int32Array::from(vec![Some(1), None, Some(3)])
2229        } else {
2230            Int32Array::from(vec![1, 2, 3])
2231        };
2232        let float_array = if float_nullable {
2233            Float64Array::from(vec![Some(3.2), None, Some(4.2)])
2234        } else {
2235            Float64Array::from(vec![3.2, 4.2, 5.2])
2236        };
2237        let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
2238        let offsets = [0, 0, 0].into_iter().collect::<ScalarBuffer<i32>>();
2239        let union_fields = [
2240            (0, Arc::new(Field::new("A", DataType::Int32, true))),
2241            (1, Arc::new(Field::new("B", DataType::Float64, true))),
2242        ]
2243        .into_iter()
2244        .collect::<UnionFields>();
2245
2246        let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2247
2248        UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap()
2249    }
2250}