1#![allow(clippy::enum_clike_unportable_variant)]
18
19use crate::{Array, ArrayRef, make_array};
20use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks};
21use arrow_buffer::buffer::NullBuffer;
22use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer};
23use arrow_data::{ArrayData, ArrayDataBuilder};
24use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode};
25use std::any::Any;
28use std::collections::HashSet;
29use std::sync::Arc;
30
/// An array in which every slot holds a value taken from one of several
/// child arrays, selected by a per-slot type id.
#[derive(Clone)]
pub struct UnionArray {
    /// The `DataType::Union` carrying the union fields and the union mode.
    data_type: DataType,
    /// One type id per slot; it selects the child array that holds the slot's value.
    type_ids: ScalarBuffer<i8>,
    /// Per-slot position inside the selected child. `Some` for dense unions,
    /// `None` for sparse unions (where the slot index itself is used).
    offsets: Option<ScalarBuffer<i32>>,
    /// Child arrays indexed by type id; entries for ids without a field are `None`.
    fields: Vec<Option<ArrayRef>>,
}
129
130impl UnionArray {
131 pub unsafe fn new_unchecked(
150 fields: UnionFields,
151 type_ids: ScalarBuffer<i8>,
152 offsets: Option<ScalarBuffer<i32>>,
153 children: Vec<ArrayRef>,
154 ) -> Self {
155 let mode = if offsets.is_some() {
156 UnionMode::Dense
157 } else {
158 UnionMode::Sparse
159 };
160
161 let len = type_ids.len();
162 let builder = ArrayData::builder(DataType::Union(fields, mode))
163 .add_buffer(type_ids.into_inner())
164 .child_data(children.into_iter().map(Array::into_data).collect())
165 .len(len);
166
167 let data = match offsets {
168 Some(offsets) => unsafe { builder.add_buffer(offsets.into_inner()).build_unchecked() },
169 None => unsafe { builder.build_unchecked() },
170 };
171 Self::from(data)
172 }
173
174 pub fn try_new(
178 fields: UnionFields,
179 type_ids: ScalarBuffer<i8>,
180 offsets: Option<ScalarBuffer<i32>>,
181 children: Vec<ArrayRef>,
182 ) -> Result<Self, ArrowError> {
183 if fields.len() != children.len() {
185 return Err(ArrowError::InvalidArgumentError(
186 "Union fields length must match child arrays length".to_string(),
187 ));
188 }
189
190 if let Some(offsets) = &offsets {
191 if offsets.len() != type_ids.len() {
193 return Err(ArrowError::InvalidArgumentError(
194 "Type Ids and Offsets lengths must match".to_string(),
195 ));
196 }
197 } else {
198 for child in &children {
200 if child.len() != type_ids.len() {
201 return Err(ArrowError::InvalidArgumentError(
202 "Sparse union child arrays must be equal in length to the length of the union".to_string(),
203 ));
204 }
205 }
206 }
207
208 let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
210 let mut array_lens = vec![i32::MIN; max_id + 1];
211 for (cd, (field_id, _)) in children.iter().zip(fields.iter()) {
212 array_lens[field_id as usize] = cd.len() as i32;
213 }
214
215 for id in &type_ids {
217 match array_lens.get(*id as usize) {
218 Some(x) if *x != i32::MIN => {}
219 _ => {
220 return Err(ArrowError::InvalidArgumentError(
221 "Type Ids values must match one of the field type ids".to_owned(),
222 ));
223 }
224 }
225 }
226
227 if let Some(offsets) = &offsets {
229 let mut iter = type_ids.iter().zip(offsets.iter());
230 if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize])
231 {
232 return Err(ArrowError::InvalidArgumentError(
233 "Offsets must be non-negative and within the length of the Array".to_owned(),
234 ));
235 }
236 }
237
238 let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) };
241 Ok(union_array)
242 }
243
244 pub fn child(&self, type_id: i8) -> &ArrayRef {
251 assert!((type_id as usize) < self.fields.len());
252 let boxed = &self.fields[type_id as usize];
253 boxed.as_ref().expect("invalid type id")
254 }
255
/// Returns the type id stored for the slot at `index`.
///
/// # Panics
///
/// Panics if `index` is not smaller than the number of type ids.
pub fn type_id(&self, index: usize) -> i8 {
    assert!(index < self.type_ids.len());
    self.type_ids[index]
}
265
/// Returns the buffer of per-slot type ids.
pub fn type_ids(&self) -> &ScalarBuffer<i8> {
    &self.type_ids
}
270
/// Returns the per-slot offsets buffer, or `None` when this array stores none
/// (the sparse case).
pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> {
    self.offsets.as_ref()
}
275
276 pub fn value_offset(&self, index: usize) -> usize {
282 assert!(index < self.len());
283 match &self.offsets {
284 Some(offsets) => offsets[index] as usize,
285 None => self.offset() + index,
286 }
287 }
288
289 pub fn value(&self, i: usize) -> ArrayRef {
297 let type_id = self.type_id(i);
298 let value_offset = self.value_offset(i);
299 let child = self.child(type_id);
300 child.slice(value_offset, 1)
301 }
302
303 pub fn type_names(&self) -> Vec<&str> {
305 match self.data_type() {
306 DataType::Union(fields, _) => fields
307 .iter()
308 .map(|(_, f)| f.name().as_str())
309 .collect::<Vec<&str>>(),
310 _ => unreachable!("Union array's data type is not a union!"),
311 }
312 }
313
314 pub fn fields(&self) -> &UnionFields {
316 match self.data_type() {
317 DataType::Union(fields, _) => fields,
318 _ => unreachable!("Union array's data type is not a union!"),
319 }
320 }
321
322 pub fn is_dense(&self) -> bool {
324 match self.data_type() {
325 DataType::Union(_, mode) => mode == &UnionMode::Dense,
326 _ => unreachable!("Union array's data type is not a union!"),
327 }
328 }
329
330 pub fn slice(&self, offset: usize, length: usize) -> Self {
332 let (offsets, fields) = match self.offsets.as_ref() {
333 Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()),
335 None => {
337 let fields = self
338 .fields
339 .iter()
340 .map(|x| x.as_ref().map(|x| x.slice(offset, length)))
341 .collect();
342 (None, fields)
343 }
344 };
345
346 Self {
347 data_type: self.data_type.clone(),
348 type_ids: self.type_ids.slice(offset, length),
349 offsets,
350 fields,
351 }
352 }
353
354 #[allow(clippy::type_complexity)]
382 pub fn into_parts(
383 self,
384 ) -> (
385 UnionFields,
386 ScalarBuffer<i8>,
387 Option<ScalarBuffer<i32>>,
388 Vec<ArrayRef>,
389 ) {
390 let Self {
391 data_type,
392 type_ids,
393 offsets,
394 mut fields,
395 } = self;
396 match data_type {
397 DataType::Union(union_fields, _) => {
398 let children = union_fields
399 .iter()
400 .map(|(type_id, _)| fields[type_id as usize].take().unwrap())
401 .collect();
402 (union_fields, type_ids, offsets, children)
403 }
404 _ => unreachable!(),
405 }
406 }
407
/// Computes the validity bits of a sparse union by building one selection
/// mask per entry of `nulls`.
///
/// `nulls` holds `(type_id, null buffer)` pairs. A slot whose type id matches
/// an entry takes that entry's bit at the same position; a slot whose type id
/// matches no entry gets a set bit.
fn mask_sparse_skip_without_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
    // Accumulates, over all entries of `nulls`:
    // - which slots select any of those fields, and
    // - the null-buffer bits of the selected field for those slots.
    let fold = |(with_nulls_selected, union_nulls), (is_field, field_nulls)| {
        (
            with_nulls_selected | is_field,
            union_nulls | (is_field & field_nulls),
        )
    };

    self.mask_sparse_helper(
        nulls,
        // Full 64-slot chunk: each iterator yields the next 64 bits of its field.
        |type_ids_chunk_array, nulls_masks_iters| {
            let (with_nulls_selected, union_nulls) = nulls_masks_iters
                .iter_mut()
                .map(|(field_type_id, field_nulls)| {
                    let field_nulls = field_nulls.next().unwrap();
                    let is_field = selection_mask(type_ids_chunk_array, *field_type_id);

                    (is_field, field_nulls)
                })
                .fold((0, 0), fold);

            // Slots that select none of the fields listed in `nulls`.
            let without_nulls_selected = !with_nulls_selected;

            without_nulls_selected | union_nulls
        },
        // Trailing slots (fewer than 64): same logic on the remainder bits.
        |type_ids_remainder, bit_chunks| {
            let (with_nulls_selected, union_nulls) = bit_chunks
                .iter()
                .map(|(field_type_id, field_bit_chunks)| {
                    let field_nulls = field_bit_chunks.remainder_bits();
                    let is_field = selection_mask(type_ids_remainder, *field_type_id);

                    (is_field, field_nulls)
                })
                .fold((0, 0), fold);

            let without_nulls_selected = !with_nulls_selected;

            without_nulls_selected | union_nulls
        },
    )
}
459
/// Computes the validity bits of a sparse union while skipping the fields
/// whose null buffer is entirely null.
///
/// `nulls` holds `(type_id, null buffer)` pairs. Entries that are entirely
/// null are dropped before masking, so slots selecting them contribute no set
/// bit. Slots selecting a field that never appeared in `nulls` get a set bit.
fn mask_sparse_skip_fully_null(&self, mut nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
    let fields = match self.data_type() {
        DataType::Union(fields, _) => fields,
        _ => unreachable!("Union array's data type is not a union!"),
    };

    let type_ids = fields.iter().map(|(id, _)| id).collect::<HashSet<_>>();
    let with_nulls = nulls.iter().map(|(id, _)| *id).collect::<HashSet<_>>();

    // Type ids of the fields that have no entry in `nulls`. Computed before
    // the `retain` below, so fully-null fields are not part of this list.
    let without_nulls_ids = type_ids
        .difference(&with_nulls)
        .copied()
        .collect::<Vec<_>>();

    // Drop the fully-null entries: only partially-null fields are masked.
    nulls.retain(|(_, nulls)| nulls.null_count() < nulls.len());

    self.mask_sparse_helper(
        nulls,
        // Full 64-slot chunk.
        |type_ids_chunk_array, nulls_masks_iters| {
            let union_nulls = nulls_masks_iters.iter_mut().fold(
                0,
                |union_nulls, (field_type_id, nulls_iter)| {
                    let field_nulls = nulls_iter.next().unwrap();

                    // All 64 bits unset: this field cannot add any set bit,
                    // so its selection mask is not computed.
                    if field_nulls == 0 {
                        union_nulls
                    } else {
                        let is_field = selection_mask(type_ids_chunk_array, *field_type_id);

                        union_nulls | (is_field & field_nulls)
                    }
                },
            );

            // Slots selecting a field that had no entry in `nulls`.
            let without_nulls_selected =
                without_nulls_selected(type_ids_chunk_array, &without_nulls_ids);

            union_nulls | without_nulls_selected
        },
        // Trailing slots (fewer than 64).
        |type_ids_remainder, bit_chunks| {
            let union_nulls =
                bit_chunks
                    .iter()
                    .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
                        let is_field = selection_mask(type_ids_remainder, *field_type_id);
                        let field_nulls = field_bit_chunks.remainder_bits();

                        union_nulls | is_field & field_nulls
                    });

            union_nulls | without_nulls_selected(type_ids_remainder, &without_nulls_ids)
        },
    )
}
522
/// Computes the validity bits of a sparse union in which every field has an
/// entry in `nulls`, computing one selection mask fewer per full chunk.
///
/// For full chunks, the first entry of `nulls` is treated as "every slot not
/// selected by any other entry", so its own selection mask is never built.
/// `logical_nulls` only picks this strategy when the number of entries
/// equals the number of fields.
///
/// # Panics
///
/// Indexes `nulls_masks_iters[1..]` and `[0]`, so `nulls` must not be empty
/// when the array has at least one full chunk of 64 slots.
fn mask_sparse_all_with_nulls_skip_one(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
    self.mask_sparse_helper(
        nulls,
        // Full 64-slot chunk.
        |type_ids_chunk_array, nulls_masks_iters| {
            // Every entry except the first gets an explicit selection mask.
            let (is_not_first, union_nulls) = nulls_masks_iters[1..]
                .iter_mut()
                .fold(
                    (0, 0),
                    |(is_not_first, union_nulls), (field_type_id, nulls_iter)| {
                        let field_nulls = nulls_iter.next().unwrap();
                        let is_field = selection_mask(type_ids_chunk_array, *field_type_id);

                        (
                            is_not_first | is_field,
                            union_nulls | (is_field & field_nulls),
                        )
                    },
                );

            // Whatever was not selected above is attributed to the first entry.
            let is_first = !is_not_first;
            let first_nulls = nulls_masks_iters[0].1.next().unwrap();

            (is_first & first_nulls) | union_nulls
        },
        // Trailing slots (fewer than 64): every entry, including the first,
        // gets an explicit selection mask here.
        |type_ids_remainder, bit_chunks| {
            bit_chunks
                .iter()
                .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
                    let field_nulls = field_bit_chunks.remainder_bits();
                    let is_field = selection_mask(type_ids_remainder, *field_type_id);

                    union_nulls | (is_field & field_nulls)
                })
        },
    )
}
568
/// Drives the sparse masking strategies: walks the type ids in chunks of 64
/// and collects one `u64` of result bits per chunk.
///
/// - `nulls`: `(type_id, null buffer)` pairs handed to the callbacks.
/// - `mask_chunk`: called once per full 64-slot chunk with the chunk's type
///   ids and one bit-chunk iterator per entry of `nulls`; the callback is
///   responsible for advancing those iterators.
/// - `mask_remainder`: called at most once, with the trailing type ids (fewer
///   than 64) and the bit chunks of every entry, from which the remainder
///   bits can be read.
///
/// The returned buffer has one bit per slot of this array.
fn mask_sparse_helper(
    &self,
    nulls: Vec<(i8, NullBuffer)>,
    mut mask_chunk: impl FnMut(&[i8; 64], &mut [(i8, BitChunkIterator)]) -> u64,
    mask_remainder: impl FnOnce(&[i8], &[(i8, BitChunks)]) -> u64,
) -> BooleanBuffer {
    let bit_chunks = nulls
        .iter()
        .map(|(type_id, nulls)| (*type_id, nulls.inner().bit_chunks()))
        .collect::<Vec<_>>();

    // One 64-bit chunk iterator per entry, advanced in lockstep with the
    // type id chunks below.
    let mut nulls_masks_iter = bit_chunks
        .iter()
        .map(|(type_id, bit_chunks)| (*type_id, bit_chunks.iter()))
        .collect::<Vec<_>>();

    let chunks_exact = self.type_ids.chunks_exact(64);
    let remainder = chunks_exact.remainder();

    let chunks = chunks_exact.map(|type_ids_chunk| {
        // `chunks_exact(64)` only yields slices of length 64, so the
        // conversion to a fixed-size array cannot fail.
        let type_ids_chunk_array = <&[i8; 64]>::try_from(type_ids_chunk).unwrap();

        mask_chunk(type_ids_chunk_array, &mut nulls_masks_iter)
    });

    // SAFETY (reviewer note): `chunks` is a `Map` over `ChunksExact`, which
    // reports an exact length; presumably that is the trusted-length
    // requirement of `from_trusted_len_iter` - confirm against its docs.
    let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) };

    if !remainder.is_empty() {
        buffer.push(mask_remainder(remainder, &bit_chunks));
    }

    BooleanBuffer::new(buffer.into(), 0, self.type_ids.len())
}
606
/// Computes the validity bits of this union by looking up, slot by slot, the
/// bit of the selected child's null buffer.
///
/// `nulls` holds `(type_id, null buffer)` pairs. Type ids with no entry are
/// treated as always valid, entries that are entirely null as always null.
/// With an offsets buffer the lookup position is the slot's offset; without
/// one it is the slot index.
fn gather_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
    let one_null = NullBuffer::new_null(1);
    let one_valid = NullBuffer::new_valid(1);

    // Lookup table indexed by `type_id as u8 as usize`, which is always below
    // 256, so indexing it cannot go out of bounds. Each entry pairs a null
    // buffer with an index mask: `Mask::Zero` forces the lookup to bit 0 of a
    // one-element buffer, `Mask::Max` leaves the lookup index unchanged.
    let mut logical_nulls_array = [(&one_valid, Mask::Zero); 256];

    for (type_id, nulls) in &nulls {
        if nulls.null_count() == nulls.len() {
            // Entirely null: a single null bit answers every lookup.
            logical_nulls_array[*type_id as u8 as usize] = (&one_null, Mask::Zero);
        } else {
            logical_nulls_array[*type_id as u8 as usize] = (nulls, Mask::Max);
        }
    }

    match &self.offsets {
        Some(offsets) => {
            // Makes `offsets.get_unchecked(i)` below in-bounds for every `i`
            // that is in-bounds for `type_ids`.
            assert_eq!(self.type_ids.len(), offsets.len());

            // SAFETY (reviewer note): assumes `collect_bool` only passes
            // indices below `self.type_ids.len()`, and that each offset is
            // within the selected child's null buffer (the bounds `try_new`
            // checks) - confirm for arrays built through unchecked paths.
            BooleanBuffer::collect_bool(self.type_ids.len(), |i| unsafe {
                let type_id = *self.type_ids.get_unchecked(i);
                let offset = *offsets.get_unchecked(i);

                let (nulls, offset_mask) = &logical_nulls_array[type_id as u8 as usize];

                nulls
                    .inner()
                    .value_unchecked(offset as usize & *offset_mask as usize)
            })
        }
        None => {
            // SAFETY (reviewer note): assumes `collect_bool` only passes
            // indices below `self.type_ids.len()`, and that every null buffer
            // in `nulls` covers at least that many slots - confirm.
            BooleanBuffer::collect_bool(self.type_ids.len(), |index| unsafe {
                let type_id = *self.type_ids.get_unchecked(index);

                let (nulls, index_mask) = &logical_nulls_array[type_id as u8 as usize];

                nulls.inner().value_unchecked(index & *index_mask as usize)
            })
        }
    }
}
668
669 fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> {
672 self.fields
673 .iter()
674 .enumerate()
675 .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?)))
676 .filter(|(_, nulls)| nulls.null_count() > 0)
677 .collect()
678 }
679}
680
681impl From<ArrayData> for UnionArray {
682 fn from(data: ArrayData) -> Self {
683 let (fields, mode) = match data.data_type() {
684 DataType::Union(fields, mode) => (fields, *mode),
685 d => panic!("UnionArray expected ArrayData with type Union got {d}"),
686 };
687 let (type_ids, offsets) = match mode {
688 UnionMode::Sparse => (
689 ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
690 None,
691 ),
692 UnionMode::Dense => (
693 ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
694 Some(ScalarBuffer::new(
695 data.buffers()[1].clone(),
696 data.offset(),
697 data.len(),
698 )),
699 ),
700 };
701
702 let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
703 let mut boxed_fields = vec![None; max_id + 1];
704 for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) {
705 boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
706 }
707 Self {
708 data_type: data.data_type().clone(),
709 type_ids,
710 offsets,
711 fields: boxed_fields,
712 }
713 }
714}
715
716impl From<UnionArray> for ArrayData {
717 fn from(array: UnionArray) -> Self {
718 let len = array.len();
719 let f = match &array.data_type {
720 DataType::Union(f, _) => f,
721 _ => unreachable!(),
722 };
723 let buffers = match array.offsets {
724 Some(o) => vec![array.type_ids.into_inner(), o.into_inner()],
725 None => vec![array.type_ids.into_inner()],
726 };
727
728 let child = f
729 .iter()
730 .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data())
731 .collect();
732
733 let builder = ArrayDataBuilder::new(array.data_type)
734 .len(len)
735 .buffers(buffers)
736 .child_data(child);
737 unsafe { builder.build_unchecked() }
738 }
739}
740
// Implements the parent module's private `Sealed` marker trait for
// `UnionArray` (presumably a prerequisite for implementing `Array` - confirm).
impl super::private::Sealed for UnionArray {}
742
743impl Array for UnionArray {
744 fn as_any(&self) -> &dyn Any {
745 self
746 }
747
748 fn to_data(&self) -> ArrayData {
749 self.clone().into()
750 }
751
752 fn into_data(self) -> ArrayData {
753 self.into()
754 }
755
756 fn data_type(&self) -> &DataType {
757 &self.data_type
758 }
759
760 fn slice(&self, offset: usize, length: usize) -> ArrayRef {
761 Arc::new(self.slice(offset, length))
762 }
763
764 fn len(&self) -> usize {
765 self.type_ids.len()
766 }
767
768 fn is_empty(&self) -> bool {
769 self.type_ids.is_empty()
770 }
771
772 fn shrink_to_fit(&mut self) {
773 self.type_ids.shrink_to_fit();
774 if let Some(offsets) = &mut self.offsets {
775 offsets.shrink_to_fit();
776 }
777 for array in self.fields.iter_mut().flatten() {
778 array.shrink_to_fit();
779 }
780 self.fields.shrink_to_fit();
781 }
782
783 fn offset(&self) -> usize {
784 0
785 }
786
787 fn nulls(&self) -> Option<&NullBuffer> {
788 None
789 }
790
791 fn logical_nulls(&self) -> Option<NullBuffer> {
792 let fields = match self.data_type() {
793 DataType::Union(fields, _) => fields,
794 _ => unreachable!(),
795 };
796
797 if fields.len() <= 1 {
798 return self.fields.iter().find_map(|field_opt| {
799 field_opt
800 .as_ref()
801 .and_then(|field| field.logical_nulls())
802 .map(|logical_nulls| {
803 if self.is_dense() {
804 self.gather_nulls(vec![(0, logical_nulls)]).into()
805 } else {
806 logical_nulls
807 }
808 })
809 });
810 }
811
812 let logical_nulls = self.fields_logical_nulls();
813
814 if logical_nulls.is_empty() {
815 return None;
816 }
817
818 let fully_null_count = logical_nulls
819 .iter()
820 .filter(|(_, nulls)| nulls.null_count() == nulls.len())
821 .count();
822
823 if fully_null_count == fields.len() {
824 if let Some((_, exactly_sized)) = logical_nulls
825 .iter()
826 .find(|(_, nulls)| nulls.len() == self.len())
827 {
828 return Some(exactly_sized.clone());
829 }
830
831 if let Some((_, bigger)) = logical_nulls
832 .iter()
833 .find(|(_, nulls)| nulls.len() > self.len())
834 {
835 return Some(bigger.slice(0, self.len()));
836 }
837
838 return Some(NullBuffer::new_null(self.len()));
839 }
840
841 let boolean_buffer = match &self.offsets {
842 Some(_) => self.gather_nulls(logical_nulls),
843 None => {
844 let gather_relative_cost = if cfg!(target_feature = "avx2") {
852 10
853 } else if cfg!(target_feature = "sse4.1") {
854 3
855 } else if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
856 2
858 } else {
859 0
863 };
864
865 let strategies = [
866 (SparseStrategy::Gather, gather_relative_cost, true),
867 (
868 SparseStrategy::MaskAllFieldsWithNullsSkipOne,
869 fields.len() - 1,
870 fields.len() == logical_nulls.len(),
871 ),
872 (
873 SparseStrategy::MaskSkipWithoutNulls,
874 logical_nulls.len(),
875 true,
876 ),
877 (
878 SparseStrategy::MaskSkipFullyNull,
879 fields.len() - fully_null_count,
880 true,
881 ),
882 ];
883
884 let (strategy, _, _) = strategies
885 .iter()
886 .filter(|(_, _, applicable)| *applicable)
887 .min_by_key(|(_, cost, _)| cost)
888 .unwrap();
889
890 match strategy {
891 SparseStrategy::Gather => self.gather_nulls(logical_nulls),
892 SparseStrategy::MaskAllFieldsWithNullsSkipOne => {
893 self.mask_sparse_all_with_nulls_skip_one(logical_nulls)
894 }
895 SparseStrategy::MaskSkipWithoutNulls => {
896 self.mask_sparse_skip_without_nulls(logical_nulls)
897 }
898 SparseStrategy::MaskSkipFullyNull => {
899 self.mask_sparse_skip_fully_null(logical_nulls)
900 }
901 }
902 }
903 };
904
905 let null_buffer = NullBuffer::from(boolean_buffer);
906
907 if null_buffer.null_count() > 0 {
908 Some(null_buffer)
909 } else {
910 None
911 }
912 }
913
914 fn is_nullable(&self) -> bool {
915 self.fields
916 .iter()
917 .flatten()
918 .any(|field| field.is_nullable())
919 }
920
921 fn get_buffer_memory_size(&self) -> usize {
922 let mut sum = self.type_ids.inner().capacity();
923 if let Some(o) = self.offsets.as_ref() {
924 sum += o.inner().capacity()
925 }
926 self.fields
927 .iter()
928 .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size()))
929 .sum::<usize>()
930 + sum
931 }
932
933 fn get_array_memory_size(&self) -> usize {
934 let mut sum = self.type_ids.inner().capacity();
935 if let Some(o) = self.offsets.as_ref() {
936 sum += o.inner().capacity()
937 }
938 std::mem::size_of::<Self>()
939 + self
940 .fields
941 .iter()
942 .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size()))
943 .sum::<usize>()
944 + sum
945 }
946}
947
948impl std::fmt::Debug for UnionArray {
949 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
950 let header = if self.is_dense() {
951 "UnionArray(Dense)\n["
952 } else {
953 "UnionArray(Sparse)\n["
954 };
955 writeln!(f, "{header}")?;
956
957 writeln!(f, "-- type id buffer:")?;
958 writeln!(f, "{:?}", self.type_ids)?;
959
960 if let Some(offsets) = &self.offsets {
961 writeln!(f, "-- offsets buffer:")?;
962 writeln!(f, "{offsets:?}")?;
963 }
964
965 let fields = match self.data_type() {
966 DataType::Union(fields, _) => fields,
967 _ => unreachable!(),
968 };
969
970 for (type_id, field) in fields.iter() {
971 let child = self.child(type_id);
972 writeln!(
973 f,
974 "-- child {}: \"{}\" ({:?})",
975 type_id,
976 field.name(),
977 field.data_type()
978 )?;
979 std::fmt::Debug::fmt(child, f)?;
980 writeln!(f)?;
981 }
982 writeln!(f, "]")
983 }
984}
985
/// The ways `logical_nulls` can compute the nulls of a sparse union; the
/// cheapest applicable one is picked per call.
enum SparseStrategy {
    /// Slot-by-slot lookup, implemented by `gather_nulls`.
    Gather,
    /// Implemented by `mask_sparse_all_with_nulls_skip_one`; only considered
    /// when every field has nulls.
    MaskAllFieldsWithNullsSkipOne,
    /// Implemented by `mask_sparse_skip_without_nulls`.
    MaskSkipWithoutNulls,
    /// Implemented by `mask_sparse_skip_fully_null`.
    MaskSkipFullyNull,
}
1000
/// Index mask that `gather_nulls` ANDs with a lookup index before reading a
/// null buffer.
#[derive(Copy, Clone)]
#[repr(usize)]
enum Mask {
    /// All bits clear: ANDing with it always yields index 0.
    Zero = 0,
    /// All bits set: ANDing with it leaves the index unchanged.
    // The discriminant is `usize::MAX`, whose value depends on the pointer
    // width; that is what the lint below flags.
    #[allow(clippy::enum_clike_unportable_variant)]
    Max = usize::MAX,
}
1009
/// Packs, into a `u64`, one bit per element of `type_ids_chunk`: bit `i` is
/// set when element `i` equals `type_id`.
///
/// Callers pass at most 64 elements (one bit per element).
fn selection_mask(type_ids_chunk: &[i8], type_id: i8) -> u64 {
    let mut packed = 0_u64;
    for (bit_idx, &candidate) in type_ids_chunk.iter().enumerate() {
        // The shift is unconditional, so this stays branch-free per element.
        packed |= u64::from(candidate == type_id) << bit_idx;
    }
    packed
}
1019
1020fn without_nulls_selected(type_ids_chunk: &[i8], without_nulls_ids: &[i8]) -> u64 {
1022 without_nulls_ids
1023 .iter()
1024 .fold(0, |fully_valid_selected, field_type_id| {
1025 fully_valid_selected | selection_mask(type_ids_chunk, *field_type_id)
1026 })
1027}
1028
1029#[cfg(test)]
1030mod tests {
1031 use super::*;
1032 use std::collections::HashSet;
1033
1034 use crate::array::Int8Type;
1035 use crate::builder::UnionBuilder;
1036 use crate::cast::AsArray;
1037 use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
1038 use crate::{Float64Array, Int32Array, Int64Array, StringArray};
1039 use crate::{Int8Array, RecordBatch};
1040 use arrow_buffer::Buffer;
1041 use arrow_schema::{Field, Schema};
1042
/// A dense union of three `Int32` fields reports the expected type ids,
/// offsets, compact child arrays and slot values.
#[test]
fn test_dense_i32() {
    let mut builder = UnionBuilder::new_dense();
    for (name, value) in [
        ("a", 1),
        ("b", 2),
        ("c", 3),
        ("a", 4),
        ("c", 5),
        ("a", 6),
        ("b", 7),
    ] {
        builder.append::<Int32Type>(name, value).unwrap();
    }
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
    let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
    let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];

    // Type ids: whole buffer first, then slot by slot.
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (i, id) in expected_type_ids.iter().enumerate() {
        assert_eq!(id, &union.type_id(i));
    }

    // Offsets: whole buffer first, then slot by slot.
    assert_eq!(*union.offsets().unwrap(), expected_offsets);
    for (i, id) in expected_offsets.iter().enumerate() {
        assert_eq!(union.value_offset(i), *id as usize);
    }

    // Dense children only hold their own values.
    for (type_id, expected_child) in [
        (0_i8, vec![1_i32, 4, 6]),
        (1, vec![2, 7]),
        (2, vec![3, 5]),
    ] {
        assert_eq!(
            *union.child(type_id).as_primitive::<Int32Type>().values(),
            expected_child
        );
    }

    assert_eq!(expected_array_values.len(), union.len());
    for (i, expected_value) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(i));
        let slot = union.value(i);
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected_value, &ints.value(0));
    }
}
1095
/// Slicing a dense single-field union `[1, null, 3, null, 4]` to
/// `[null, 3, null]` yields matching logical nulls.
#[test]
fn slice_union_array_single_field() {
    let mut builder = UnionBuilder::new_dense();
    for value in [Some(1), None, Some(3), None, Some(4)] {
        match value {
            Some(v) => builder.append::<Int32Type>("a", v).unwrap(),
            None => builder.append_null::<Int32Type>("a").unwrap(),
        }
    }
    let union_array = builder.build().unwrap();

    let union_slice = union_array.slice(1, 3);
    let logical_nulls = union_slice.logical_nulls().unwrap();

    assert_eq!(logical_nulls.len(), 3);
    for (i, expect_null) in [true, false, true].into_iter().enumerate() {
        if expect_null {
            assert!(logical_nulls.is_null(i));
        } else {
            assert!(logical_nulls.is_valid(i));
        }
    }
}
1119
/// Same checks as the small dense test, on a single field with 1024 values.
#[test]
#[cfg_attr(miri, ignore)]
fn test_dense_i32_large() {
    let expected_type_ids = vec![0_i8; 1024];
    let expected_offsets: Vec<i32> = (0..1024).collect();
    let expected_array_values: Vec<i32> = (1..=1024).collect();

    let mut builder = UnionBuilder::new_dense();
    for value in &expected_array_values {
        builder.append::<Int32Type>("a", *value).unwrap();
    }
    let union = builder.build().unwrap();

    // Type ids: whole buffer first, then slot by slot.
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (i, id) in expected_type_ids.iter().enumerate() {
        assert_eq!(id, &union.type_id(i));
    }

    // Offsets: whole buffer first, then slot by slot.
    assert_eq!(*union.offsets().unwrap(), expected_offsets);
    for (i, id) in expected_offsets.iter().enumerate() {
        assert_eq!(union.value_offset(i), *id as usize);
    }

    for (i, expected_value) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(i));
        let slot = union.value(i);
        let ints = slot.as_primitive::<Int32Type>();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected_value, &ints.value(0));
    }
}
1156
/// A dense union mixing `Int32` and `Int64` fields returns each slot with the
/// right concrete type and value.
#[test]
fn test_dense_mixed() {
    enum Expected {
        Int32(i32),
        Int64(i64),
    }

    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int64Type>("c", 3).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    builder.append::<Int64Type>("c", 5).unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    let expected = [
        Expected::Int32(1),
        Expected::Int64(3),
        Expected::Int32(4),
        Expected::Int64(5),
        Expected::Int32(6),
    ];

    assert_eq!(5, union.len());
    for i in 0..union.len() {
        let slot = union.value(i);
        assert!(!union.is_null(i));
        match expected.get(i) {
            Some(Expected::Int32(want)) => {
                let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
                assert_eq!(ints.len(), 1);
                assert_eq!(*want, ints.value(0));
            }
            Some(Expected::Int64(want)) => {
                let ints = slot.as_any().downcast_ref::<Int64Array>().unwrap();
                assert_eq!(ints.len(), 1);
                assert_eq!(*want, ints.value(0));
            }
            None => unreachable!(),
        }
    }
}
1206
/// A dense union with a null appended to one field exposes that null through
/// the slot's one-element array, and leaves the other slots valid.
#[test]
fn test_dense_mixed_with_nulls() {
    enum Expected {
        Int32(i32),
        Int64(i64),
        Null,
    }

    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int64Type>("c", 3).unwrap();
    builder.append::<Int32Type>("a", 10).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    let expected = [
        Expected::Int32(1),
        Expected::Int64(3),
        Expected::Int32(10),
        Expected::Null,
        Expected::Int32(6),
    ];

    assert_eq!(5, union.len());
    for i in 0..union.len() {
        let slot = union.value(i);
        match expected.get(i) {
            Some(Expected::Int32(want)) => {
                let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
                assert!(!ints.is_null(0));
                assert_eq!(ints.len(), 1);
                assert_eq!(*want, ints.value(0));
            }
            Some(Expected::Int64(want)) => {
                let ints = slot.as_any().downcast_ref::<Int64Array>().unwrap();
                assert!(!ints.is_null(0));
                assert_eq!(ints.len(), 1);
                assert_eq!(*want, ints.value(0));
            }
            Some(Expected::Null) => assert!(slot.is_null(0)),
            None => unreachable!(),
        }
    }
}
1254
/// Slicing a dense union with a null keeps values and the null aligned with
/// the sliced slots.
#[test]
fn test_dense_mixed_with_nulls_and_offset() {
    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int64Type>("c", 3).unwrap();
    builder.append::<Int32Type>("a", 10).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    let slice = union.slice(2, 3);
    let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();

    // Expected content of the slice: `Some` is an `Int32` value, `None` a null.
    let expected = [Some(10_i32), None, Some(6)];

    assert_eq!(3, new_union.len());
    for i in 0..new_union.len() {
        let slot = new_union.value(i);
        match expected.get(i) {
            Some(Some(want)) => {
                let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
                assert!(!ints.is_null(0));
                assert_eq!(ints.len(), 1);
                assert_eq!(*want, ints.value(0));
            }
            Some(None) => assert!(slot.is_null(0)),
            None => unreachable!(),
        }
    }
}
1291
/// A dense union built with `try_new` from string, int and float children
/// reports the given type ids and offsets and resolves every slot.
#[test]
fn test_dense_mixed_with_str() {
    let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
    let int_array = Int32Array::from(vec![5, 6]);
    let float_array = Float64Array::from(vec![10.0]);

    let type_ids: ScalarBuffer<i8> = [1, 0, 0, 2, 0, 1].into_iter().collect();
    let offsets: ScalarBuffer<i32> = [0, 0, 1, 0, 2, 1].into_iter().collect();

    let fields = [
        (0, Arc::new(Field::new("A", DataType::Utf8, false))),
        (1, Arc::new(Field::new("B", DataType::Int32, false))),
        (2, Arc::new(Field::new("C", DataType::Float64, false))),
    ]
    .into_iter()
    .collect::<UnionFields>();
    let children: Vec<ArrayRef> = vec![
        Arc::new(string_array) as ArrayRef,
        Arc::new(int_array),
        Arc::new(float_array),
    ];
    let array =
        UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap();

    // Type ids: whole buffer first, then slot by slot.
    assert_eq!(*array.type_ids(), type_ids);
    for (i, id) in type_ids.iter().enumerate() {
        assert_eq!(id, &array.type_id(i));
    }

    // Offsets: whole buffer first, then slot by slot.
    assert_eq!(*array.offsets().unwrap(), offsets);
    for (i, id) in offsets.iter().enumerate() {
        assert_eq!(*id as usize, array.value_offset(i));
    }

    assert_eq!(6, array.len());

    // Each helper reads slot `i` as a one-element array of the given type.
    let int_at = |i: usize| {
        let slot = array.value(i);
        slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0)
    };
    let str_at = |i: usize| {
        let slot = array.value(i);
        slot.as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
            .value(0)
            .to_string()
    };
    let float_at = |i: usize| {
        let slot = array.value(i);
        slot.as_any()
            .downcast_ref::<Float64Array>()
            .unwrap()
            .value(0)
    };

    assert_eq!(5, int_at(0));
    assert_eq!("foo", str_at(1));
    assert_eq!("bar", str_at(2));
    assert_eq!(10.0, float_at(3));
    assert_eq!("baz", str_at(4));
    assert_eq!(6, int_at(5));
}
1375
/// A sparse union of three `Int32` fields has no offsets, full-length
/// children padded with zeros, and the expected slot values.
#[test]
fn test_sparse_i32() {
    let mut builder = UnionBuilder::new_sparse();
    for (name, value) in [
        ("a", 1),
        ("b", 2),
        ("c", 3),
        ("a", 4),
        ("c", 5),
        ("a", 6),
        ("b", 7),
    ] {
        builder.append::<Int32Type>(name, value).unwrap();
    }
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
    let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];

    // Type ids: whole buffer first, then slot by slot.
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (i, id) in expected_type_ids.iter().enumerate() {
        assert_eq!(id, &union.type_id(i));
    }

    // Sparse unions carry no offsets.
    assert!(union.offsets().is_none());

    // Every sparse child spans the whole union.
    for (type_id, expected_child) in [
        (0_i8, vec![1_i32, 0, 0, 4, 0, 6, 0]),
        (1, vec![0, 2, 0, 0, 0, 0, 7]),
        (2, vec![0, 0, 3, 0, 5, 0, 0]),
    ] {
        assert_eq!(
            *union.child(type_id).as_primitive::<Int32Type>().values(),
            expected_child
        );
    }

    assert_eq!(expected_array_values.len(), union.len());
    for (i, expected_value) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(i));
        let slot = union.value(i);
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected_value, &ints.value(0));
    }
}
1424
/// Sparse union alternating between an `Int32` field "a" and a `Float64`
/// field "c"; every slot must be non-null and read back the appended value.
#[test]
fn test_sparse_mixed() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    builder.append::<Float64Type>("c", 5.0).unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let sparse = builder.build().unwrap();

    let want_type_ids = vec![0_i8, 1, 0, 1, 0];

    assert_eq!(*sparse.type_ids(), want_type_ids);
    for (row, &type_id) in want_type_ids.iter().enumerate() {
        assert_eq!(type_id, sparse.type_id(row));
    }

    // A sparse union carries no offsets.
    assert!(sparse.offsets().is_none());

    // Each slot comes back as a one-element array of the field's type.
    let expect_int = |slot: &ArrayRef, want: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(want, ints.value(0));
    };
    let expect_float = |slot: &ArrayRef, want: f64| {
        let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(floats.len(), 1);
        assert_eq!(floats.value(0), want);
    };

    for row in 0..sparse.len() {
        let slot = sparse.value(row);
        assert!(!sparse.is_null(row));
        match row {
            0 => expect_int(&slot, 1),
            1 => expect_float(&slot, 3.0),
            2 => expect_int(&slot, 4),
            3 => expect_float(&slot, 5.0),
            4 => expect_int(&slot, 6),
            _ => unreachable!(),
        }
    }
}
1484
/// Sparse union with a null appended to the `Int32` field "a": the null slot
/// must come back as a null child value while its neighbours stay valid.
#[test]
fn test_sparse_mixed_with_nulls() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let sparse = builder.build().unwrap();

    let want_type_ids = vec![0_i8, 0, 1, 0];

    assert_eq!(*sparse.type_ids(), want_type_ids);
    for (row, &type_id) in want_type_ids.iter().enumerate() {
        assert_eq!(type_id, sparse.type_id(row));
    }

    // A sparse union carries no offsets.
    assert!(sparse.offsets().is_none());

    // Valid slots are one-element, non-null arrays of the field's type.
    let expect_int = |slot: &ArrayRef, want: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert!(!ints.is_null(0));
        assert_eq!(ints.len(), 1);
        assert_eq!(want, ints.value(0));
    };
    let expect_float = |slot: &ArrayRef, want: f64| {
        let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
        assert!(!floats.is_null(0));
        assert_eq!(floats.len(), 1);
        assert_eq!(floats.value(0), want);
    };

    for row in 0..sparse.len() {
        let slot = sparse.value(row);
        match row {
            0 => expect_int(&slot, 1),
            // The slot written with `append_null`.
            1 => assert!(slot.is_null(0)),
            2 => expect_float(&slot, 3.0),
            3 => expect_int(&slot, 4),
            _ => unreachable!(),
        }
    }
}
1534
/// Slices a five-slot sparse union down to its last four slots and checks
/// that values and nulls are read relative to the slice start.
#[test]
fn test_sparse_mixed_with_nulls_and_offset() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append_null::<Float64Type>("c").unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let full = builder.build().unwrap();

    // Drop the first slot; the remaining ones are: null, 3.0, null, 4.
    let sliced = full.slice(1, 4);
    let tail = sliced.as_any().downcast_ref::<UnionArray>().unwrap();

    assert_eq!(4, tail.len());
    (0..tail.len()).for_each(|row| {
        let slot = tail.value(row);
        match row {
            0 | 2 => assert!(slot.is_null(0)),
            1 => {
                let floats = slot.as_primitive::<Float64Type>();
                assert!(!floats.is_null(0));
                assert_eq!(floats.len(), 1);
                assert_eq!(floats.value(0), 3_f64);
            }
            3 => {
                let ints = slot.as_primitive::<Int32Type>();
                assert!(!ints.is_null(0));
                assert_eq!(ints.len(), 1);
                assert_eq!(4_i32, ints.value(0));
            }
            _ => unreachable!(),
        }
    });
}
1572
1573 fn test_union_validity(union_array: &UnionArray) {
1574 assert_eq!(union_array.null_count(), 0);
1575
1576 for i in 0..union_array.len() {
1577 assert!(!union_array.is_null(i));
1578 assert!(union_array.is_valid(i));
1579 }
1580 }
1581
/// Runs the validity helper against a sparse and a dense union that both
/// had nulls appended to their children: the union itself must still report
/// no nulls.
#[test]
fn test_union_array_validity() {
    /// Appends the same five slots (two of them nulls) and builds the union.
    fn fill(mut builder: UnionBuilder) -> UnionArray {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        builder.build().unwrap()
    }

    let sparse = fill(UnionBuilder::new_sparse());
    test_union_validity(&sparse);

    let dense = fill(UnionBuilder::new_dense());
    test_union_validity(&dense);
}
1604
/// Appending an `Int32` value to a field first written as `Float32` must
/// fail with a message naming both types.
#[test]
fn test_type_check() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Float32Type>("a", 1.0).unwrap();

    let message = builder.append::<Int32Type>("a", 1).unwrap_err().to_string();
    let expected_fragment =
        "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32";
    assert!(message.contains(expected_fragment), "{message}");
}
1618
/// Wraps a union in a `RecordBatch`, slices the batch and checks that the
/// sliced union column exposes the expected type ids, values and nulls for
/// both the sparse and the dense layout.
#[test]
fn slice_union_array() {
    /// Appends five slots: 1, null (Int32), 3.0, null (Float64), 4.
    fn build_union(mut builder: UnionBuilder) -> UnionArray {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        builder.build().unwrap()
    }

    /// Puts the union into a single-column record batch.
    fn into_batch(union: UnionArray) -> RecordBatch {
        let column = Field::new("struct_array", union.data_type().clone(), true);
        let schema = Arc::new(Schema::new(vec![column]));
        RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap()
    }

    /// Checks the three slots of a batch sliced with `slice(1, 3)`.
    fn check_slice(batch_slice: RecordBatch) {
        let sliced = batch_slice
            .column(0)
            .as_any()
            .downcast_ref::<UnionArray>()
            .unwrap();

        for (row, want_type_id) in [0, 1, 1].into_iter().enumerate() {
            assert_eq!(sliced.type_id(row), want_type_id);
        }

        // Slot 0: the null appended to the Int32 field.
        let first = sliced.value(0);
        let ints = first.as_primitive::<Int32Type>();
        assert_eq!(ints.len(), 1);
        assert!(ints.is_null(0));

        // Slot 1: the Float64 value 3.0.
        let second = sliced.value(1);
        let floats = second.as_primitive::<Float64Type>();
        assert_eq!(floats.len(), 1);
        assert!(floats.is_valid(0));
        assert_eq!(floats.value(0), 3.0);

        // Slot 2: the null appended to the Float64 field.
        let third = sliced.value(2);
        let floats = third.as_primitive::<Float64Type>();
        assert_eq!(floats.len(), 1);
        assert!(floats.is_null(0));
    }

    let sparse_batch = into_batch(build_union(UnionBuilder::new_sparse()));
    check_slice(sparse_batch.slice(1, 3));

    let dense_batch = into_batch(build_union(UnionBuilder::new_dense()));
    check_slice(dense_batch.slice(1, 3));
}
1683
/// Dense union whose fields use the non-contiguous type ids 8 (strings),
/// 4 (integers) and 9 (floats); every slot must resolve to the child that
/// owns its type id, at the position given by the offsets buffer.
#[test]
fn test_custom_type_ids() {
    let union_fields = UnionFields::try_new(
        vec![8, 4, 9],
        vec![
            Field::new("strings", DataType::Utf8, false),
            Field::new("integers", DataType::Int32, false),
            Field::new("floats", DataType::Float64, false),
        ],
    )
    .unwrap();
    let data_type = DataType::Union(union_fields, UnionMode::Dense);

    let strings = StringArray::from(vec!["foo", "bar", "baz"]);
    let ints = Int32Array::from(vec![5, 6, 4]);
    let floats = Float64Array::from(vec![10.0]);

    let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
    let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);

    let child_data = vec![strings.into_data(), ints.into_data(), floats.into_data()];
    let data = ArrayData::builder(data_type)
        .len(7)
        .buffers(vec![type_ids, value_offsets])
        .child_data(child_data)
        .build()
        .unwrap();

    let array = UnionArray::from(data);

    // One checker per child type; each slot is a one-element array.
    let expect_int = |row: usize, want: i32| {
        let v = array.value(row);
        assert_eq!(v.data_type(), &DataType::Int32);
        assert_eq!(v.len(), 1);
        assert_eq!(v.as_primitive::<Int32Type>().value(0), want);
    };
    let expect_str = |row: usize, want: &str| {
        let v = array.value(row);
        assert_eq!(v.data_type(), &DataType::Utf8);
        assert_eq!(v.len(), 1);
        assert_eq!(v.as_string::<i32>().value(0), want);
    };
    let expect_float = |row: usize, want: f64| {
        let v = array.value(row);
        assert_eq!(v.data_type(), &DataType::Float64);
        assert_eq!(v.len(), 1);
        assert_eq!(v.as_primitive::<Float64Type>().value(0), want);
    };

    expect_int(0, 5);
    expect_str(1, "foo");
    expect_int(2, 6);
    expect_str(3, "bar");
    expect_float(4, 10.0);
    expect_int(5, 4);
    expect_str(6, "baz");
}
1754
/// Decomposes a dense and a sparse union with `into_parts` and checks that
/// the parts are as expected and can be fed back into `try_new`.
#[test]
fn into_parts() {
    let mut builder = UnionBuilder::new_dense();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int8Type>("b", 2).unwrap();
    builder.append::<Int32Type>("a", 3).unwrap();
    let dense = builder.build().unwrap();

    let field_a = Arc::new(Field::new("a", DataType::Int32, false));
    let field_b = Arc::new(Field::new("b", DataType::Int8, false));
    let want_fields = [&field_a, &field_b];

    let (union_fields, type_ids, offsets, children) = dense.into_parts();
    let got_fields = union_fields
        .iter()
        .map(|(_, field)| field)
        .collect::<Vec<_>>();
    assert_eq!(got_fields, want_fields);
    assert_eq!(type_ids, [0, 1, 0]);
    // Dense layout: an offsets buffer must be present.
    assert!(offsets.is_some());
    assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);

    let rebuilt = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt.is_ok());
    assert_eq!(rebuilt.unwrap().len(), 3);

    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Int8Type>("b", 2).unwrap();
    builder.append::<Int32Type>("a", 3).unwrap();
    let sparse = builder.build().unwrap();

    let (union_fields, type_ids, offsets, children) = sparse.into_parts();
    assert_eq!(type_ids, [0, 1, 0]);
    // Sparse layout: no offsets buffer.
    assert!(offsets.is_none());

    let rebuilt = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt.is_ok());
    assert_eq!(rebuilt.unwrap().len(), 3);
}
1797
/// `into_parts` on a dense union with custom type ids (8, 4, 9): the set of
/// type ids in the buffer must equal the set of field type ids, and the
/// parts must rebuild into a union of the same length.
#[test]
fn into_parts_custom_type_ids() {
    let set_field_type_ids: [i8; 3] = [8, 4, 9];
    let union_fields = UnionFields::try_new(
        set_field_type_ids,
        [
            Field::new("strings", DataType::Utf8, false),
            Field::new("integers", DataType::Int32, false),
            Field::new("floats", DataType::Float64, false),
        ],
    )
    .unwrap();
    let data_type = DataType::Union(union_fields, UnionMode::Dense);

    let strings = StringArray::from(vec!["foo", "bar", "baz"]);
    let ints = Int32Array::from(vec![5, 6, 4]);
    let floats = Float64Array::from(vec![10.0]);
    let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
    let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);

    let child_data = vec![strings.into_data(), ints.into_data(), floats.into_data()];
    let data = ArrayData::builder(data_type)
        .len(7)
        .buffers(vec![type_ids, value_offsets])
        .child_data(child_data)
        .build()
        .unwrap();
    let array = UnionArray::from(data);

    let (union_fields, type_ids, offsets, children) = array.into_parts();

    // Compare as sets: order and multiplicity of the ids are irrelevant here.
    let used_ids = type_ids.iter().collect::<HashSet<_>>();
    let declared_ids = set_field_type_ids.iter().collect::<HashSet<_>>();
    assert_eq!(used_ids, declared_ids);

    let rebuilt = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt.is_ok());
    let rebuilt = rebuilt.unwrap();
    assert_eq!(rebuilt.len(), 7);
}
1840
/// Exercises the validation errors of `UnionArray::try_new`, checking the
/// exact error message for each rejected combination of arguments.
#[test]
fn test_invalid() {
    let fields = UnionFields::try_new(
        [3, 2],
        [
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ],
    )
    .unwrap();
    let sparse_children: Vec<ArrayRef> = vec![
        Arc::new(StringArray::from_iter_values(["a", "b"])),
        Arc::new(StringArray::from_iter_values(["c", "d"])),
    ];

    // Sparse: three type ids, but the children only hold two values each.
    let message = UnionArray::try_new(
        fields.clone(),
        vec![3, 3, 2].into(),
        None,
        sparse_children.clone(),
    )
    .unwrap_err()
    .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Sparse union child arrays must be equal in length to the length of the union"
    );

    // Type id 1 is not one of the declared field ids (3 and 2).
    let message = UnionArray::try_new(
        fields.clone(),
        vec![1, 2].into(),
        None,
        sparse_children.clone(),
    )
    .unwrap_err()
    .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Type Ids values must match one of the field type ids"
    );

    // Same for type id 7.
    let message = UnionArray::try_new(fields.clone(), vec![7, 2].into(), None, sparse_children)
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Type Ids values must match one of the field type ids"
    );

    let dense_children: Vec<ArrayRef> = vec![
        Arc::new(StringArray::from_iter_values(["a", "b"])),
        Arc::new(StringArray::from_iter_values(["c"])),
    ];
    let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]);

    // Dense: these offsets all fall inside their child, so this must succeed.
    UnionArray::try_new(
        fields.clone(),
        type_ids.clone(),
        Some(vec![0, 1, 0].into()),
        dense_children.clone(),
    )
    .unwrap();

    // Offset 1 for the last slot points past the one-element second child.
    let message = UnionArray::try_new(
        fields.clone(),
        type_ids.clone(),
        Some(vec![0, 1, 1].into()),
        dense_children.clone(),
    )
    .unwrap_err()
    .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Offsets must be non-negative and within the length of the Array"
    );

    // Two offsets for three type ids.
    let message = UnionArray::try_new(
        fields.clone(),
        type_ids.clone(),
        Some(vec![0, 1].into()),
        dense_children,
    )
    .unwrap_err()
    .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Type Ids and Offsets lengths must match"
    );

    // Two declared fields but no child arrays at all.
    let message = UnionArray::try_new(fields.clone(), type_ids, None, vec![])
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Union fields length must match child arrays length"
    );
}
1912
/// Covers the cases where `logical_nulls` has a uniform answer: `None` when
/// there are no fields or no child holds a null, and an all-null buffer when
/// every child is entirely null.
#[test]
fn test_logical_nulls_fast_paths() {
    // Union without any field.
    let no_fields =
        UnionArray::try_new(UnionFields::empty(), vec![].into(), None, vec![]).unwrap();
    assert_eq!(no_fields.logical_nulls(), None);

    // Two fields declared non-nullable, children without nulls.
    let fields = UnionFields::try_new(
        [1, 3],
        [
            Field::new("a", DataType::Int8, false),
            Field::new("b", DataType::Int8, false),
        ],
    )
    .unwrap();
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int8Array::from_value(5, 1)),
        Arc::new(Int8Array::from_value(5, 1)),
    ];
    let array = UnionArray::try_new(fields, vec![1].into(), None, children).unwrap();
    assert_eq!(array.logical_nulls(), None);

    let nullable_fields = UnionFields::try_new(
        [1, 3],
        [
            Field::new("a", DataType::Int8, true),
            Field::new("b", DataType::Int8, true),
        ],
    )
    .unwrap();

    // Fields declared nullable, but the children still hold no nulls.
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int8Array::from_value(-5, 2)),
        Arc::new(Int8Array::from_value(-5, 2)),
    ];
    let array =
        UnionArray::try_new(nullable_fields.clone(), vec![1, 1].into(), None, children).unwrap();
    assert_eq!(array.logical_nulls(), None);

    // Sparse union whose children are entirely null.
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int8Array::new_null(2)),
        Arc::new(Int8Array::new_null(2)),
    ];
    let array =
        UnionArray::try_new(nullable_fields.clone(), vec![1, 1].into(), None, children).unwrap();
    assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));

    // Dense union whose children are entirely null (and longer than the union).
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int8Array::new_null(3)),
        Arc::new(Int8Array::new_null(3)),
    ];
    let array = UnionArray::try_new(
        nullable_fields.clone(),
        vec![1, 1].into(),
        Some(vec![0, 1].into()),
        children,
    )
    .unwrap();
    assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));
}
1990
/// Dense union over three children (fully valid ints, floats with one null,
/// a fully null string array): `logical_nulls` and `gather_nulls` must both
/// yield the validity of the child value each slot points at.
#[test]
fn test_dense_union_logical_nulls_gather() {
    let children = vec![
        Arc::new(Int32Array::from(vec![1, 2])) as Arc<dyn Array>,
        Arc::new(Float64Array::from(vec![Some(3.2), None])),
        Arc::new(StringArray::new_null(1)),
    ];
    let type_ids = ScalarBuffer::from(vec![1_i8, 1, 3, 3, 4, 4]);
    let offsets = ScalarBuffer::from(vec![0_i32, 1, 0, 1, 0, 0]);

    let dense = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();

    // ints: valid, valid; floats: valid, null; strings: null, null.
    let want = BooleanBuffer::from(vec![true, true, true, false, false, false]);

    assert_eq!(want, dense.logical_nulls().unwrap().into_inner());
    assert_eq!(want, dense.gather_nulls(dense.fields_logical_nulls()));
}
2015
/// Sparse union with one fully null child and one child with mixed nulls:
/// `logical_nulls` and `mask_sparse_all_with_nulls_skip_one` must agree, for
/// a 4-slot array and for a 160-slot array.
#[test]
fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() {
    let fields: UnionFields = [1_i8, 3]
        .into_iter()
        .zip(
            [
                Field::new("A", DataType::Int32, true),
                Field::new("B", DataType::Float64, true),
            ]
            .into_iter()
            .map(Arc::new),
        )
        .collect();

    // Both ways of computing the null mask have to produce `want`.
    let assert_mask = |array: &UnionArray, want: &BooleanBuffer| {
        assert_eq!(*want, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            *want,
            array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
        );
    };

    let children = vec![
        Arc::new(Int32Array::new_null(4)) as Arc<dyn Array>,
        Arc::new(Float64Array::from(vec![None, None, Some(3.2), None])),
    ];
    let type_ids = ScalarBuffer::from(vec![1_i8, 1, 3, 3]);

    let small = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap();

    // Only slot 2 selects a non-null value (float child, index 2).
    let want = BooleanBuffer::from(vec![false, false, true, false]);
    assert_mask(&small, &want);

    // 160 slots — presumably chosen to span several 64-bit chunks plus a
    // remainder; confirm against the mask implementation.
    let len = 2 * 64 + 32;

    let ints = Int32Array::new_null(len);
    let floats = Float64Array::from_iter([Some(3.2), None].into_iter().cycle().take(len));
    let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3].into_iter().cycle().take(len));

    let large =
        UnionArray::try_new(fields, type_ids, None, vec![Arc::new(ints), Arc::new(floats)])
            .unwrap();

    let want =
        BooleanBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len));

    assert_eq!(large.len(), len);
    assert_mask(&large, &want);
}
2067
/// Sparse union with two fully valid children and one string child that
/// holds nulls: `logical_nulls` and `mask_sparse_skip_without_nulls` must
/// agree, for a 6-slot array and for a 160-slot array.
#[test]
fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_valid() {
    // Both ways of computing the null mask have to produce `want`.
    let assert_mask = |array: &UnionArray, want: &BooleanBuffer| {
        assert_eq!(*want, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            *want,
            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
        );
    };

    let children = vec![
        Arc::new(Int32Array::from_value(2, 6)) as Arc<dyn Array>,
        Arc::new(Float64Array::from_value(4.2, 6)),
        Arc::new(StringArray::new_null(6)),
    ];
    let type_ids = ScalarBuffer::from(vec![1_i8, 1, 3, 3, 4, 4]);

    let small = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();

    // Only the two slots selecting the all-null string child are null.
    let want = BooleanBuffer::from(vec![true, true, true, true, false, false]);
    assert_mask(&small, &want);

    // 160 slots — presumably chosen to span several 64-bit chunks plus a
    // remainder; confirm against the mask implementation.
    let len = 2 * 64 + 32;

    let children = vec![
        Arc::new(Int32Array::from_value(2, len)) as Arc<dyn Array>,
        Arc::new(Float64Array::from_value(4.2, len)),
        Arc::new(StringArray::from_iter(
            [None, Some("a")].into_iter().cycle().take(len),
        )),
    ];
    let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));

    let large = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();

    // The string child alternates null / "a", so within each group of six
    // slots only the first string slot is null.
    let want = BooleanBuffer::from_iter(
        [true, true, true, true, false, true]
            .into_iter()
            .cycle()
            .take(len),
    );

    assert_eq!(large.len(), len);
    assert_mask(&large, &want);
}
2122
/// Sparse union with two fully null children and one fully valid float
/// child: `logical_nulls` and `mask_sparse_skip_fully_null` must agree, for
/// a 6-slot array and for a 160-slot array.
#[test]
fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_null() {
    // Both ways of computing the null mask have to produce `want`.
    let assert_mask = |array: &UnionArray, want: &BooleanBuffer| {
        assert_eq!(*want, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            *want,
            array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
        );
    };

    let children = vec![
        Arc::new(Int32Array::new_null(6)) as Arc<dyn Array>,
        Arc::new(Float64Array::from_value(4.2, 6)),
        Arc::new(StringArray::new_null(6)),
    ];
    let type_ids = ScalarBuffer::from(vec![1_i8, 1, 3, 3, 4, 4]);

    let small = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();

    // Only the two slots selecting the float child are valid.
    let want = BooleanBuffer::from(vec![false, false, true, true, false, false]);
    assert_mask(&small, &want);

    // 160 slots — presumably chosen to span several 64-bit chunks plus a
    // remainder; confirm against the mask implementation.
    let len = 2 * 64 + 32;

    let children = vec![
        Arc::new(Int32Array::new_null(len)) as Arc<dyn Array>,
        Arc::new(Float64Array::from_value(4.2, len)),
        Arc::new(StringArray::new_null(len)),
    ];
    let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));

    let large = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();

    let want = BooleanBuffer::from_iter(
        [false, false, true, true, false, false]
            .into_iter()
            .cycle()
            .take(len),
    );

    assert_eq!(large.len(), len);
    assert_mask(&large, &want);
}
2177
/// Sparse union with 50 `Int32` fields (type ids 1, 3, 5, ...) whose
/// children cycle through fully valid, mixed and fully null arrays:
/// `logical_nulls` and `gather_nulls` must agree.
#[test]
fn test_sparse_union_logical_nulls_gather() {
    let n_fields = 50;

    // Odd type ids, one nullable Int32 field per id.
    let fields: UnionFields = (1..)
        .step_by(2)
        .map(|type_id| {
            let field = Field::new(format!("f{type_id}"), DataType::Int32, true);
            (type_id, Arc::new(field))
        })
        .take(n_fields)
        .collect();

    // Child k is: fully valid, mixed, fully null, fully valid, ...
    let children: Vec<ArrayRef> = [
        Arc::new(Int32Array::from_value(2, 4)) as ArrayRef,
        Arc::new(Int32Array::from(vec![None, None, Some(1), None])),
        Arc::new(Int32Array::new_null(4)),
    ]
    .into_iter()
    .cycle()
    .take(n_fields)
    .collect();

    let sparse = UnionArray::try_new(fields, vec![1, 3, 3, 5].into(), None, children).unwrap();

    // Slot 0 -> valid child; slots 1 and 2 -> mixed child (null, then valid);
    // slot 3 -> fully null child.
    let want = BooleanBuffer::from(vec![true, false, true, false]);

    assert_eq!(want, sparse.logical_nulls().unwrap().into_inner());
    assert_eq!(want, sparse.gather_nulls(sparse.fields_logical_nulls()));
}
2216
2217 fn union_fields() -> UnionFields {
2218 [
2219 (1, Arc::new(Field::new("A", DataType::Int32, true))),
2220 (3, Arc::new(Field::new("B", DataType::Float64, true))),
2221 (4, Arc::new(Field::new("C", DataType::Utf8, true))),
2222 ]
2223 .into_iter()
2224 .collect()
2225 }
2226
/// `is_nullable` must be false only when neither child array holds a null.
#[test]
fn test_is_nullable() {
    // (int child has nulls, float child has nulls, expected `is_nullable`)
    let cases = [
        (false, false, false),
        (true, false, true),
        (false, true, true),
        (true, true, true),
    ];

    for (int_nullable, float_nullable, want) in cases {
        let array = create_union_array(int_nullable, float_nullable);
        assert_eq!(array.is_nullable(), want);
    }
}
2234
2235 fn create_union_array(int_nullable: bool, float_nullable: bool) -> UnionArray {
2242 let int_array = if int_nullable {
2243 Int32Array::from(vec![Some(1), None, Some(3)])
2244 } else {
2245 Int32Array::from(vec![1, 2, 3])
2246 };
2247 let float_array = if float_nullable {
2248 Float64Array::from(vec![Some(3.2), None, Some(4.2)])
2249 } else {
2250 Float64Array::from(vec![3.2, 4.2, 5.2])
2251 };
2252 let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
2253 let offsets = [0, 0, 0].into_iter().collect::<ScalarBuffer<i32>>();
2254 let union_fields = [
2255 (0, Arc::new(Field::new("A", DataType::Int32, true))),
2256 (1, Arc::new(Field::new("B", DataType::Float64, true))),
2257 ]
2258 .into_iter()
2259 .collect::<UnionFields>();
2260
2261 let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2262
2263 UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap()
2264 }
2265}