1#![allow(clippy::enum_clike_unportable_variant)]
18
19use crate::{Array, ArrayRef, make_array};
20use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks};
21use arrow_buffer::buffer::NullBuffer;
22use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer};
23use arrow_data::{ArrayData, ArrayDataBuilder};
24use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode};
25use std::any::Any;
28use std::collections::HashSet;
29use std::sync::Arc;
30
/// An array whose slots each hold a value taken from one of several child arrays;
/// the child for a slot is selected by that slot's type id.
///
/// A union is dense when it stores per-slot offsets into the selected child, and
/// sparse otherwise (children are then aligned slot-for-slot with the type ids).
#[derive(Clone)]
pub struct UnionArray {
    // Always `DataType::Union`; the accessors treat any other variant as unreachable.
    data_type: DataType,
    // The type id of every slot; its length is the length of the union.
    type_ids: ScalarBuffer<i8>,
    // Per-slot offsets into the selected child: `Some` for dense unions, `None` for sparse.
    offsets: Option<ScalarBuffer<i32>>,
    // Child arrays indexed by type id; `None` for ids not declared by the union fields.
    fields: Vec<Option<ArrayRef>>,
}
129
130impl UnionArray {
131 pub unsafe fn new_unchecked(
150 fields: UnionFields,
151 type_ids: ScalarBuffer<i8>,
152 offsets: Option<ScalarBuffer<i32>>,
153 children: Vec<ArrayRef>,
154 ) -> Self {
155 let mode = if offsets.is_some() {
156 UnionMode::Dense
157 } else {
158 UnionMode::Sparse
159 };
160
161 let len = type_ids.len();
162 let builder = ArrayData::builder(DataType::Union(fields, mode))
163 .add_buffer(type_ids.into_inner())
164 .child_data(children.into_iter().map(Array::into_data).collect())
165 .len(len);
166
167 let data = match offsets {
168 Some(offsets) => unsafe { builder.add_buffer(offsets.into_inner()).build_unchecked() },
169 None => unsafe { builder.build_unchecked() },
170 };
171 Self::from(data)
172 }
173
174 pub fn try_new(
178 fields: UnionFields,
179 type_ids: ScalarBuffer<i8>,
180 offsets: Option<ScalarBuffer<i32>>,
181 children: Vec<ArrayRef>,
182 ) -> Result<Self, ArrowError> {
183 if fields.len() != children.len() {
185 return Err(ArrowError::InvalidArgumentError(
186 "Union fields length must match child arrays length".to_string(),
187 ));
188 }
189
190 if let Some(offsets) = &offsets {
191 if offsets.len() != type_ids.len() {
193 return Err(ArrowError::InvalidArgumentError(
194 "Type Ids and Offsets lengths must match".to_string(),
195 ));
196 }
197 } else {
198 for child in &children {
200 if child.len() != type_ids.len() {
201 return Err(ArrowError::InvalidArgumentError(
202 "Sparse union child arrays must be equal in length to the length of the union".to_string(),
203 ));
204 }
205 }
206 }
207
208 let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
210 let mut array_lens = vec![i32::MIN; max_id + 1];
211 for (cd, (field_id, _)) in children.iter().zip(fields.iter()) {
212 array_lens[field_id as usize] = cd.len() as i32;
213 }
214
215 for id in &type_ids {
217 match array_lens.get(*id as usize) {
218 Some(x) if *x != i32::MIN => {}
219 _ => {
220 return Err(ArrowError::InvalidArgumentError(
221 "Type Ids values must match one of the field type ids".to_owned(),
222 ));
223 }
224 }
225 }
226
227 if let Some(offsets) = &offsets {
229 let mut iter = type_ids.iter().zip(offsets.iter());
230 if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize])
231 {
232 return Err(ArrowError::InvalidArgumentError(
233 "Offsets must be non-negative and within the length of the Array".to_owned(),
234 ));
235 }
236 }
237
238 let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) };
241 Ok(union_array)
242 }
243
244 pub fn child(&self, type_id: i8) -> &ArrayRef {
251 assert!((type_id as usize) < self.fields.len());
252 let boxed = &self.fields[type_id as usize];
253 boxed.as_ref().expect("invalid type id")
254 }
255
256 pub fn type_id(&self, index: usize) -> i8 {
262 assert!(index < self.type_ids.len());
263 self.type_ids[index]
264 }
265
266 pub fn type_ids(&self) -> &ScalarBuffer<i8> {
268 &self.type_ids
269 }
270
271 pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> {
273 self.offsets.as_ref()
274 }
275
276 pub fn value_offset(&self, index: usize) -> usize {
282 assert!(index < self.len());
283 match &self.offsets {
284 Some(offsets) => offsets[index] as usize,
285 None => self.offset() + index,
286 }
287 }
288
289 pub fn value(&self, i: usize) -> ArrayRef {
297 let type_id = self.type_id(i);
298 let value_offset = self.value_offset(i);
299 let child = self.child(type_id);
300 child.slice(value_offset, 1)
301 }
302
303 pub fn type_names(&self) -> Vec<&str> {
305 match self.data_type() {
306 DataType::Union(fields, _) => fields
307 .iter()
308 .map(|(_, f)| f.name().as_str())
309 .collect::<Vec<&str>>(),
310 _ => unreachable!("Union array's data type is not a union!"),
311 }
312 }
313
314 pub fn fields(&self) -> &UnionFields {
316 match self.data_type() {
317 DataType::Union(fields, _) => fields,
318 _ => unreachable!("Union array's data type is not a union!"),
319 }
320 }
321
322 pub fn is_dense(&self) -> bool {
324 match self.data_type() {
325 DataType::Union(_, mode) => mode == &UnionMode::Dense,
326 _ => unreachable!("Union array's data type is not a union!"),
327 }
328 }
329
330 pub fn slice(&self, offset: usize, length: usize) -> Self {
332 let (offsets, fields) = match self.offsets.as_ref() {
333 Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()),
335 None => {
337 let fields = self
338 .fields
339 .iter()
340 .map(|x| x.as_ref().map(|x| x.slice(offset, length)))
341 .collect();
342 (None, fields)
343 }
344 };
345
346 Self {
347 data_type: self.data_type.clone(),
348 type_ids: self.type_ids.slice(offset, length),
349 offsets,
350 fields,
351 }
352 }
353
354 #[allow(clippy::type_complexity)]
382 pub fn into_parts(
383 self,
384 ) -> (
385 UnionFields,
386 ScalarBuffer<i8>,
387 Option<ScalarBuffer<i32>>,
388 Vec<ArrayRef>,
389 ) {
390 let Self {
391 data_type,
392 type_ids,
393 offsets,
394 mut fields,
395 } = self;
396 match data_type {
397 DataType::Union(union_fields, _) => {
398 let children = union_fields
399 .iter()
400 .map(|(type_id, _)| fields[type_id as usize].take().unwrap())
401 .collect();
402 (union_fields, type_ids, offsets, children)
403 }
404 _ => unreachable!(),
405 }
406 }
407
408 fn mask_sparse_skip_without_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
410 let fold = |(with_nulls_selected, union_nulls), (is_field, field_nulls)| {
416 (
417 with_nulls_selected | is_field,
418 union_nulls | (is_field & field_nulls),
419 )
420 };
421
422 self.mask_sparse_helper(
423 nulls,
424 |type_ids_chunk_array, nulls_masks_iters| {
425 let (with_nulls_selected, union_nulls) = nulls_masks_iters
426 .iter_mut()
427 .map(|(field_type_id, field_nulls)| {
428 let field_nulls = field_nulls.next().unwrap();
429 let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
430
431 (is_field, field_nulls)
432 })
433 .fold((0, 0), fold);
434
435 let without_nulls_selected = !with_nulls_selected;
437
438 without_nulls_selected | union_nulls
441 },
442 |type_ids_remainder, bit_chunks| {
443 let (with_nulls_selected, union_nulls) = bit_chunks
444 .iter()
445 .map(|(field_type_id, field_bit_chunks)| {
446 let field_nulls = field_bit_chunks.remainder_bits();
447 let is_field = selection_mask(type_ids_remainder, *field_type_id);
448
449 (is_field, field_nulls)
450 })
451 .fold((0, 0), fold);
452
453 let without_nulls_selected = !with_nulls_selected;
454
455 without_nulls_selected | union_nulls
456 },
457 )
458 }
459
460 fn mask_sparse_skip_fully_null(&self, mut nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
462 let fields = match self.data_type() {
463 DataType::Union(fields, _) => fields,
464 _ => unreachable!("Union array's data type is not a union!"),
465 };
466
467 let type_ids = fields.iter().map(|(id, _)| id).collect::<HashSet<_>>();
468 let with_nulls = nulls.iter().map(|(id, _)| *id).collect::<HashSet<_>>();
469
470 let without_nulls_ids = type_ids
471 .difference(&with_nulls)
472 .copied()
473 .collect::<Vec<_>>();
474
475 nulls.retain(|(_, nulls)| nulls.null_count() < nulls.len());
476
477 self.mask_sparse_helper(
482 nulls,
483 |type_ids_chunk_array, nulls_masks_iters| {
484 let union_nulls = nulls_masks_iters.iter_mut().fold(
485 0,
486 |union_nulls, (field_type_id, nulls_iter)| {
487 let field_nulls = nulls_iter.next().unwrap();
488
489 if field_nulls == 0 {
490 union_nulls
491 } else {
492 let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
493
494 union_nulls | (is_field & field_nulls)
495 }
496 },
497 );
498
499 let without_nulls_selected =
501 without_nulls_selected(type_ids_chunk_array, &without_nulls_ids);
502
503 union_nulls | without_nulls_selected
506 },
507 |type_ids_remainder, bit_chunks| {
508 let union_nulls =
509 bit_chunks
510 .iter()
511 .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
512 let is_field = selection_mask(type_ids_remainder, *field_type_id);
513 let field_nulls = field_bit_chunks.remainder_bits();
514
515 union_nulls | is_field & field_nulls
516 });
517
518 union_nulls | without_nulls_selected(type_ids_remainder, &without_nulls_ids)
519 },
520 )
521 }
522
523 fn mask_sparse_all_with_nulls_skip_one(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
525 self.mask_sparse_helper(
532 nulls,
533 |type_ids_chunk_array, nulls_masks_iters| {
534 let (is_not_first, union_nulls) = nulls_masks_iters[1..] .iter_mut()
536 .fold(
537 (0, 0),
538 |(is_not_first, union_nulls), (field_type_id, nulls_iter)| {
539 let field_nulls = nulls_iter.next().unwrap();
540 let is_field = selection_mask(type_ids_chunk_array, *field_type_id);
541
542 (
543 is_not_first | is_field,
544 union_nulls | (is_field & field_nulls),
545 )
546 },
547 );
548
549 let is_first = !is_not_first;
550 let first_nulls = nulls_masks_iters[0].1.next().unwrap();
551
552 (is_first & first_nulls) | union_nulls
553 },
554 |type_ids_remainder, bit_chunks| {
555 bit_chunks
556 .iter()
557 .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
558 let field_nulls = field_bit_chunks.remainder_bits();
559 let is_field = selection_mask(type_ids_remainder, *field_type_id);
562
563 union_nulls | (is_field & field_nulls)
564 })
565 },
566 )
567 }
568
569 fn mask_sparse_helper(
572 &self,
573 nulls: Vec<(i8, NullBuffer)>,
574 mut mask_chunk: impl FnMut(&[i8; 64], &mut [(i8, BitChunkIterator)]) -> u64,
575 mask_remainder: impl FnOnce(&[i8], &[(i8, BitChunks)]) -> u64,
576 ) -> BooleanBuffer {
577 let bit_chunks = nulls
578 .iter()
579 .map(|(type_id, nulls)| (*type_id, nulls.inner().bit_chunks()))
580 .collect::<Vec<_>>();
581
582 let mut nulls_masks_iter = bit_chunks
583 .iter()
584 .map(|(type_id, bit_chunks)| (*type_id, bit_chunks.iter()))
585 .collect::<Vec<_>>();
586
587 let chunks_exact = self.type_ids.chunks_exact(64);
588 let remainder = chunks_exact.remainder();
589
590 let chunks = chunks_exact.map(|type_ids_chunk| {
591 let type_ids_chunk_array = <&[i8; 64]>::try_from(type_ids_chunk).unwrap();
592
593 mask_chunk(type_ids_chunk_array, &mut nulls_masks_iter)
594 });
595
596 let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) };
599
600 if !remainder.is_empty() {
601 buffer.push(mask_remainder(remainder, &bit_chunks));
602 }
603
604 BooleanBuffer::new(buffer.into(), 0, self.type_ids.len())
605 }
606
607 fn gather_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
609 let one_null = NullBuffer::new_null(1);
610 let one_valid = NullBuffer::new_valid(1);
611
612 let mut logical_nulls_array = [(&one_valid, Mask::Zero); 256];
619
620 for (type_id, nulls) in &nulls {
621 if nulls.null_count() == nulls.len() {
622 logical_nulls_array[*type_id as u8 as usize] = (&one_null, Mask::Zero);
624 } else {
625 logical_nulls_array[*type_id as u8 as usize] = (nulls, Mask::Max);
626 }
627 }
628
629 match &self.offsets {
630 Some(offsets) => {
631 assert_eq!(self.type_ids.len(), offsets.len());
632
633 BooleanBuffer::collect_bool(self.type_ids.len(), |i| unsafe {
634 let type_id = *self.type_ids.get_unchecked(i);
636 let offset = *offsets.get_unchecked(i);
638
639 let (nulls, offset_mask) = &logical_nulls_array[type_id as u8 as usize];
640
641 nulls
647 .inner()
648 .value_unchecked(offset as usize & *offset_mask as usize)
649 })
650 }
651 None => {
652 BooleanBuffer::collect_bool(self.type_ids.len(), |index| unsafe {
653 let type_id = *self.type_ids.get_unchecked(index);
655
656 let (nulls, index_mask) = &logical_nulls_array[type_id as u8 as usize];
657
658 nulls.inner().value_unchecked(index & *index_mask as usize)
664 })
665 }
666 }
667 }
668
669 fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> {
672 self.fields
673 .iter()
674 .enumerate()
675 .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?)))
676 .filter(|(_, nulls)| nulls.null_count() > 0)
677 .collect()
678 }
679}
680
681impl From<ArrayData> for UnionArray {
682 fn from(data: ArrayData) -> Self {
683 let (fields, mode) = match data.data_type() {
684 DataType::Union(fields, mode) => (fields, *mode),
685 d => panic!("UnionArray expected ArrayData with type Union got {d}"),
686 };
687 let (type_ids, offsets) = match mode {
688 UnionMode::Sparse => (
689 ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
690 None,
691 ),
692 UnionMode::Dense => (
693 ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
694 Some(ScalarBuffer::new(
695 data.buffers()[1].clone(),
696 data.offset(),
697 data.len(),
698 )),
699 ),
700 };
701
702 let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
703 let mut boxed_fields = vec![None; max_id + 1];
704 for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) {
705 boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
706 }
707 Self {
708 data_type: data.data_type().clone(),
709 type_ids,
710 offsets,
711 fields: boxed_fields,
712 }
713 }
714}
715
716impl From<UnionArray> for ArrayData {
717 fn from(array: UnionArray) -> Self {
718 let len = array.len();
719 let f = match &array.data_type {
720 DataType::Union(f, _) => f,
721 _ => unreachable!(),
722 };
723 let buffers = match array.offsets {
724 Some(o) => vec![array.type_ids.into_inner(), o.into_inner()],
725 None => vec![array.type_ids.into_inner()],
726 };
727
728 let child = f
729 .iter()
730 .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data())
731 .collect();
732
733 let builder = ArrayDataBuilder::new(array.data_type)
734 .len(len)
735 .buffers(buffers)
736 .child_data(child);
737 unsafe { builder.build_unchecked() }
738 }
739}
740
741unsafe impl Array for UnionArray {
743 fn as_any(&self) -> &dyn Any {
744 self
745 }
746
747 fn to_data(&self) -> ArrayData {
748 self.clone().into()
749 }
750
751 fn into_data(self) -> ArrayData {
752 self.into()
753 }
754
755 fn data_type(&self) -> &DataType {
756 &self.data_type
757 }
758
759 fn slice(&self, offset: usize, length: usize) -> ArrayRef {
760 Arc::new(self.slice(offset, length))
761 }
762
763 fn len(&self) -> usize {
764 self.type_ids.len()
765 }
766
767 fn is_empty(&self) -> bool {
768 self.type_ids.is_empty()
769 }
770
771 fn shrink_to_fit(&mut self) {
772 self.type_ids.shrink_to_fit();
773 if let Some(offsets) = &mut self.offsets {
774 offsets.shrink_to_fit();
775 }
776 for array in self.fields.iter_mut().flatten() {
777 array.shrink_to_fit();
778 }
779 self.fields.shrink_to_fit();
780 }
781
782 fn offset(&self) -> usize {
783 0
784 }
785
786 fn nulls(&self) -> Option<&NullBuffer> {
787 None
788 }
789
790 fn logical_nulls(&self) -> Option<NullBuffer> {
791 let fields = match self.data_type() {
792 DataType::Union(fields, _) => fields,
793 _ => unreachable!(),
794 };
795
796 if fields.len() <= 1 {
797 return self.fields.iter().find_map(|field_opt| {
798 field_opt
799 .as_ref()
800 .and_then(|field| field.logical_nulls())
801 .map(|logical_nulls| {
802 if self.is_dense() {
803 self.gather_nulls(vec![(0, logical_nulls)]).into()
804 } else {
805 logical_nulls
806 }
807 })
808 });
809 }
810
811 let logical_nulls = self.fields_logical_nulls();
812
813 if logical_nulls.is_empty() {
814 return None;
815 }
816
817 let fully_null_count = logical_nulls
818 .iter()
819 .filter(|(_, nulls)| nulls.null_count() == nulls.len())
820 .count();
821
822 if fully_null_count == fields.len() {
823 if let Some((_, exactly_sized)) = logical_nulls
824 .iter()
825 .find(|(_, nulls)| nulls.len() == self.len())
826 {
827 return Some(exactly_sized.clone());
828 }
829
830 if let Some((_, bigger)) = logical_nulls
831 .iter()
832 .find(|(_, nulls)| nulls.len() > self.len())
833 {
834 return Some(bigger.slice(0, self.len()));
835 }
836
837 return Some(NullBuffer::new_null(self.len()));
838 }
839
840 let boolean_buffer = match &self.offsets {
841 Some(_) => self.gather_nulls(logical_nulls),
842 None => {
843 let gather_relative_cost = if cfg!(target_feature = "avx2") {
851 10
852 } else if cfg!(target_feature = "sse4.1") {
853 3
854 } else if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
855 2
857 } else {
858 0
862 };
863
864 let strategies = [
865 (SparseStrategy::Gather, gather_relative_cost, true),
866 (
867 SparseStrategy::MaskAllFieldsWithNullsSkipOne,
868 fields.len() - 1,
869 fields.len() == logical_nulls.len(),
870 ),
871 (
872 SparseStrategy::MaskSkipWithoutNulls,
873 logical_nulls.len(),
874 true,
875 ),
876 (
877 SparseStrategy::MaskSkipFullyNull,
878 fields.len() - fully_null_count,
879 true,
880 ),
881 ];
882
883 let (strategy, _, _) = strategies
884 .iter()
885 .filter(|(_, _, applicable)| *applicable)
886 .min_by_key(|(_, cost, _)| cost)
887 .unwrap();
888
889 match strategy {
890 SparseStrategy::Gather => self.gather_nulls(logical_nulls),
891 SparseStrategy::MaskAllFieldsWithNullsSkipOne => {
892 self.mask_sparse_all_with_nulls_skip_one(logical_nulls)
893 }
894 SparseStrategy::MaskSkipWithoutNulls => {
895 self.mask_sparse_skip_without_nulls(logical_nulls)
896 }
897 SparseStrategy::MaskSkipFullyNull => {
898 self.mask_sparse_skip_fully_null(logical_nulls)
899 }
900 }
901 }
902 };
903
904 let null_buffer = NullBuffer::from(boolean_buffer);
905
906 if null_buffer.null_count() > 0 {
907 Some(null_buffer)
908 } else {
909 None
910 }
911 }
912
913 fn is_nullable(&self) -> bool {
914 self.fields
915 .iter()
916 .flatten()
917 .any(|field| field.is_nullable())
918 }
919
920 fn get_buffer_memory_size(&self) -> usize {
921 let mut sum = self.type_ids.inner().capacity();
922 if let Some(o) = self.offsets.as_ref() {
923 sum += o.inner().capacity()
924 }
925 self.fields
926 .iter()
927 .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size()))
928 .sum::<usize>()
929 + sum
930 }
931
932 fn get_array_memory_size(&self) -> usize {
933 let mut sum = self.type_ids.inner().capacity();
934 if let Some(o) = self.offsets.as_ref() {
935 sum += o.inner().capacity()
936 }
937 std::mem::size_of::<Self>()
938 + self
939 .fields
940 .iter()
941 .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size()))
942 .sum::<usize>()
943 + sum
944 }
945}
946
947impl std::fmt::Debug for UnionArray {
948 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
949 let header = if self.is_dense() {
950 "UnionArray(Dense)\n["
951 } else {
952 "UnionArray(Sparse)\n["
953 };
954 writeln!(f, "{header}")?;
955
956 writeln!(f, "-- type id buffer:")?;
957 writeln!(f, "{:?}", self.type_ids)?;
958
959 if let Some(offsets) = &self.offsets {
960 writeln!(f, "-- offsets buffer:")?;
961 writeln!(f, "{offsets:?}")?;
962 }
963
964 let fields = match self.data_type() {
965 DataType::Union(fields, _) => fields,
966 _ => unreachable!(),
967 };
968
969 for (type_id, field) in fields.iter() {
970 let child = self.child(type_id);
971 writeln!(
972 f,
973 "-- child {}: \"{}\" ({:?})",
974 type_id,
975 field.name(),
976 field.data_type()
977 )?;
978 std::fmt::Debug::fmt(child, f)?;
979 writeln!(f)?;
980 }
981 writeln!(f, "]")
982 }
983}
984
/// Ways of computing the logical nulls of a sparse union; `logical_nulls` picks the
/// applicable strategy with the lowest estimated cost.
enum SparseStrategy {
    /// Read, for each slot, the validity bit of the field selected by its type id.
    Gather,
    /// Mask every field's nulls with its selection mask, deriving the first field's
    /// mask as the complement of the others. Only applicable when every field has nulls.
    MaskAllFieldsWithNullsSkipOne,
    /// Mask only fields that have nulls; slots selecting any other field are valid.
    MaskSkipWithoutNulls,
    /// Like `MaskSkipWithoutNulls`, but also skips fields that are entirely null,
    /// whose slots are left null.
    MaskSkipFullyNull,
}
999
/// Index mask used by `gather_nulls`, applied with a bitwise AND.
///
/// `Zero` maps every index to 0, which is used with 1-length null buffers. `Max` leaves
/// the index unchanged, which is used with a field's real null buffer.
#[derive(Copy, Clone)]
#[repr(usize)]
enum Mask {
    Zero = 0,
    // `usize::MAX` does not fit in 32 bits on 64-bit targets, which clippy flags as
    // unportable; the all-ones value for the target's `usize` width is intended here.
    #[allow(clippy::enum_clike_unportable_variant)]
    Max = usize::MAX,
}
1008
/// Returns a bitmask whose bit `i` is set exactly when `type_ids_chunk[i] == type_id`.
///
/// Callers pass chunks of at most 64 type ids, so every position maps to one bit of
/// the returned `u64`. Each bit is produced from the comparison result rather than a
/// branch.
fn selection_mask(type_ids_chunk: &[i8], type_id: i8) -> u64 {
    let mut mask = 0u64;
    for (position, &candidate) in type_ids_chunk.iter().enumerate() {
        mask |= u64::from(candidate == type_id) << position;
    }
    mask
}
1018
1019fn without_nulls_selected(type_ids_chunk: &[i8], without_nulls_ids: &[i8]) -> u64 {
1021 without_nulls_ids
1022 .iter()
1023 .fold(0, |fully_valid_selected, field_type_id| {
1024 fully_valid_selected | selection_mask(type_ids_chunk, *field_type_id)
1025 })
1026}
1027
1028#[cfg(test)]
1029mod tests {
1030 use super::*;
1031 use std::collections::HashSet;
1032
1033 use crate::array::Int8Type;
1034 use crate::builder::UnionBuilder;
1035 use crate::cast::AsArray;
1036 use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
1037 use crate::{Float64Array, Int32Array, Int64Array, StringArray};
1038 use crate::{Int8Array, RecordBatch};
1039 use arrow_buffer::Buffer;
1040 use arrow_schema::{Field, Schema};
1041
1042 #[test]
1043 fn test_dense_i32() {
1044 let mut builder = UnionBuilder::new_dense();
1045 builder.append::<Int32Type>("a", 1).unwrap();
1046 builder.append::<Int32Type>("b", 2).unwrap();
1047 builder.append::<Int32Type>("c", 3).unwrap();
1048 builder.append::<Int32Type>("a", 4).unwrap();
1049 builder.append::<Int32Type>("c", 5).unwrap();
1050 builder.append::<Int32Type>("a", 6).unwrap();
1051 builder.append::<Int32Type>("b", 7).unwrap();
1052 let union = builder.build().unwrap();
1053
1054 let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
1055 let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
1056 let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];
1057
1058 assert_eq!(*union.type_ids(), expected_type_ids);
1060 for (i, id) in expected_type_ids.iter().enumerate() {
1061 assert_eq!(id, &union.type_id(i));
1062 }
1063
1064 assert_eq!(*union.offsets().unwrap(), expected_offsets);
1066 for (i, id) in expected_offsets.iter().enumerate() {
1067 assert_eq!(union.value_offset(i), *id as usize);
1068 }
1069
1070 assert_eq!(
1072 *union.child(0).as_primitive::<Int32Type>().values(),
1073 [1_i32, 4, 6]
1074 );
1075 assert_eq!(
1076 *union.child(1).as_primitive::<Int32Type>().values(),
1077 [2_i32, 7]
1078 );
1079 assert_eq!(
1080 *union.child(2).as_primitive::<Int32Type>().values(),
1081 [3_i32, 5]
1082 );
1083
1084 assert_eq!(expected_array_values.len(), union.len());
1085 for (i, expected_value) in expected_array_values.iter().enumerate() {
1086 assert!(!union.is_null(i));
1087 let slot = union.value(i);
1088 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1089 assert_eq!(slot.len(), 1);
1090 let value = slot.value(0);
1091 assert_eq!(expected_value, &value);
1092 }
1093 }
1094
1095 #[test]
1096 fn slice_union_array_single_field() {
1097 let union_array = {
1100 let mut builder = UnionBuilder::new_dense();
1101 builder.append::<Int32Type>("a", 1).unwrap();
1102 builder.append_null::<Int32Type>("a").unwrap();
1103 builder.append::<Int32Type>("a", 3).unwrap();
1104 builder.append_null::<Int32Type>("a").unwrap();
1105 builder.append::<Int32Type>("a", 4).unwrap();
1106 builder.build().unwrap()
1107 };
1108
1109 let union_slice = union_array.slice(1, 3);
1111 let logical_nulls = union_slice.logical_nulls().unwrap();
1112
1113 assert_eq!(logical_nulls.len(), 3);
1114 assert!(logical_nulls.is_null(0));
1115 assert!(logical_nulls.is_valid(1));
1116 assert!(logical_nulls.is_null(2));
1117 }
1118
1119 #[test]
1120 #[cfg_attr(miri, ignore)]
1121 fn test_dense_i32_large() {
1122 let mut builder = UnionBuilder::new_dense();
1123
1124 let expected_type_ids = vec![0_i8; 1024];
1125 let expected_offsets: Vec<_> = (0..1024).collect();
1126 let expected_array_values: Vec<_> = (1..=1024).collect();
1127
1128 expected_array_values
1129 .iter()
1130 .for_each(|v| builder.append::<Int32Type>("a", *v).unwrap());
1131
1132 let union = builder.build().unwrap();
1133
1134 assert_eq!(*union.type_ids(), expected_type_ids);
1136 for (i, id) in expected_type_ids.iter().enumerate() {
1137 assert_eq!(id, &union.type_id(i));
1138 }
1139
1140 assert_eq!(*union.offsets().unwrap(), expected_offsets);
1142 for (i, id) in expected_offsets.iter().enumerate() {
1143 assert_eq!(union.value_offset(i), *id as usize);
1144 }
1145
1146 for (i, expected_value) in expected_array_values.iter().enumerate() {
1147 assert!(!union.is_null(i));
1148 let slot = union.value(i);
1149 let slot = slot.as_primitive::<Int32Type>();
1150 assert_eq!(slot.len(), 1);
1151 let value = slot.value(0);
1152 assert_eq!(expected_value, &value);
1153 }
1154 }
1155
1156 #[test]
1157 fn test_dense_mixed() {
1158 let mut builder = UnionBuilder::new_dense();
1159 builder.append::<Int32Type>("a", 1).unwrap();
1160 builder.append::<Int64Type>("c", 3).unwrap();
1161 builder.append::<Int32Type>("a", 4).unwrap();
1162 builder.append::<Int64Type>("c", 5).unwrap();
1163 builder.append::<Int32Type>("a", 6).unwrap();
1164 let union = builder.build().unwrap();
1165
1166 assert_eq!(5, union.len());
1167 for i in 0..union.len() {
1168 let slot = union.value(i);
1169 assert!(!union.is_null(i));
1170 match i {
1171 0 => {
1172 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1173 assert_eq!(slot.len(), 1);
1174 let value = slot.value(0);
1175 assert_eq!(1_i32, value);
1176 }
1177 1 => {
1178 let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1179 assert_eq!(slot.len(), 1);
1180 let value = slot.value(0);
1181 assert_eq!(3_i64, value);
1182 }
1183 2 => {
1184 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1185 assert_eq!(slot.len(), 1);
1186 let value = slot.value(0);
1187 assert_eq!(4_i32, value);
1188 }
1189 3 => {
1190 let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1191 assert_eq!(slot.len(), 1);
1192 let value = slot.value(0);
1193 assert_eq!(5_i64, value);
1194 }
1195 4 => {
1196 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1197 assert_eq!(slot.len(), 1);
1198 let value = slot.value(0);
1199 assert_eq!(6_i32, value);
1200 }
1201 _ => unreachable!(),
1202 }
1203 }
1204 }
1205
1206 #[test]
1207 fn test_dense_mixed_with_nulls() {
1208 let mut builder = UnionBuilder::new_dense();
1209 builder.append::<Int32Type>("a", 1).unwrap();
1210 builder.append::<Int64Type>("c", 3).unwrap();
1211 builder.append::<Int32Type>("a", 10).unwrap();
1212 builder.append_null::<Int32Type>("a").unwrap();
1213 builder.append::<Int32Type>("a", 6).unwrap();
1214 let union = builder.build().unwrap();
1215
1216 assert_eq!(5, union.len());
1217 for i in 0..union.len() {
1218 let slot = union.value(i);
1219 match i {
1220 0 => {
1221 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1222 assert!(!slot.is_null(0));
1223 assert_eq!(slot.len(), 1);
1224 let value = slot.value(0);
1225 assert_eq!(1_i32, value);
1226 }
1227 1 => {
1228 let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1229 assert!(!slot.is_null(0));
1230 assert_eq!(slot.len(), 1);
1231 let value = slot.value(0);
1232 assert_eq!(3_i64, value);
1233 }
1234 2 => {
1235 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1236 assert!(!slot.is_null(0));
1237 assert_eq!(slot.len(), 1);
1238 let value = slot.value(0);
1239 assert_eq!(10_i32, value);
1240 }
1241 3 => assert!(slot.is_null(0)),
1242 4 => {
1243 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1244 assert!(!slot.is_null(0));
1245 assert_eq!(slot.len(), 1);
1246 let value = slot.value(0);
1247 assert_eq!(6_i32, value);
1248 }
1249 _ => unreachable!(),
1250 }
1251 }
1252 }
1253
1254 #[test]
1255 fn test_dense_mixed_with_nulls_and_offset() {
1256 let mut builder = UnionBuilder::new_dense();
1257 builder.append::<Int32Type>("a", 1).unwrap();
1258 builder.append::<Int64Type>("c", 3).unwrap();
1259 builder.append::<Int32Type>("a", 10).unwrap();
1260 builder.append_null::<Int32Type>("a").unwrap();
1261 builder.append::<Int32Type>("a", 6).unwrap();
1262 let union = builder.build().unwrap();
1263
1264 let slice = union.slice(2, 3);
1265 let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();
1266
1267 assert_eq!(3, new_union.len());
1268 for i in 0..new_union.len() {
1269 let slot = new_union.value(i);
1270 match i {
1271 0 => {
1272 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1273 assert!(!slot.is_null(0));
1274 assert_eq!(slot.len(), 1);
1275 let value = slot.value(0);
1276 assert_eq!(10_i32, value);
1277 }
1278 1 => assert!(slot.is_null(0)),
1279 2 => {
1280 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1281 assert!(!slot.is_null(0));
1282 assert_eq!(slot.len(), 1);
1283 let value = slot.value(0);
1284 assert_eq!(6_i32, value);
1285 }
1286 _ => unreachable!(),
1287 }
1288 }
1289 }
1290
1291 #[test]
1292 fn test_dense_mixed_with_str() {
1293 let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1294 let int_array = Int32Array::from(vec![5, 6]);
1295 let float_array = Float64Array::from(vec![10.0]);
1296
1297 let type_ids = [1, 0, 0, 2, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1298 let offsets = [0, 0, 1, 0, 2, 1]
1299 .into_iter()
1300 .collect::<ScalarBuffer<i32>>();
1301
1302 let fields = [
1303 (0, Arc::new(Field::new("A", DataType::Utf8, false))),
1304 (1, Arc::new(Field::new("B", DataType::Int32, false))),
1305 (2, Arc::new(Field::new("C", DataType::Float64, false))),
1306 ]
1307 .into_iter()
1308 .collect::<UnionFields>();
1309 let children = [
1310 Arc::new(string_array) as Arc<dyn Array>,
1311 Arc::new(int_array),
1312 Arc::new(float_array),
1313 ]
1314 .into_iter()
1315 .collect();
1316 let array =
1317 UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap();
1318
1319 assert_eq!(*array.type_ids(), type_ids);
1321 for (i, id) in type_ids.iter().enumerate() {
1322 assert_eq!(id, &array.type_id(i));
1323 }
1324
1325 assert_eq!(*array.offsets().unwrap(), offsets);
1327 for (i, id) in offsets.iter().enumerate() {
1328 assert_eq!(*id as usize, array.value_offset(i));
1329 }
1330
1331 assert_eq!(6, array.len());
1333
1334 let slot = array.value(0);
1335 let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1336 assert_eq!(5, value);
1337
1338 let slot = array.value(1);
1339 let value = slot
1340 .as_any()
1341 .downcast_ref::<StringArray>()
1342 .unwrap()
1343 .value(0);
1344 assert_eq!("foo", value);
1345
1346 let slot = array.value(2);
1347 let value = slot
1348 .as_any()
1349 .downcast_ref::<StringArray>()
1350 .unwrap()
1351 .value(0);
1352 assert_eq!("bar", value);
1353
1354 let slot = array.value(3);
1355 let value = slot
1356 .as_any()
1357 .downcast_ref::<Float64Array>()
1358 .unwrap()
1359 .value(0);
1360 assert_eq!(10.0, value);
1361
1362 let slot = array.value(4);
1363 let value = slot
1364 .as_any()
1365 .downcast_ref::<StringArray>()
1366 .unwrap()
1367 .value(0);
1368 assert_eq!("baz", value);
1369
1370 let slot = array.value(5);
1371 let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1372 assert_eq!(6, value);
1373 }
1374
#[test]
fn test_sparse_i32() {
    // (column, value) for each appended slot, in insertion order.
    let appends = [
        ("a", 1),
        ("b", 2),
        ("c", 3),
        ("a", 4),
        ("c", 5),
        ("a", 6),
        ("b", 7),
    ];
    let mut builder = UnionBuilder::new_sparse();
    for (column, value) in appends {
        builder.append::<Int32Type>(column, value).unwrap();
    }
    let union = builder.build().unwrap();

    // Type ids follow first appearance of each column: "a" = 0, "b" = 1, "c" = 2.
    let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (idx, type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, &union.type_id(idx));
    }

    // Sparse layout: no offsets buffer, and every child spans the whole union,
    // holding zeros in the slots owned by other children.
    assert!(union.offsets().is_none());
    for (type_id, expected) in [
        (0_i8, [1_i32, 0, 0, 4, 0, 6, 0]),
        (1, [0, 2, 0, 0, 0, 0, 7]),
        (2, [0, 0, 3, 0, 5, 0, 0]),
    ] {
        assert_eq!(
            *union.child(type_id).as_primitive::<Int32Type>().values(),
            expected
        );
    }

    // Each slot yields a one-element Int32 array holding the appended value.
    assert_eq!(appends.len(), union.len());
    for (idx, &(_, expected)) in appends.iter().enumerate() {
        assert!(!union.is_null(idx));
        let slot = union.value(idx);
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    }
}
1423
#[test]
fn test_sparse_mixed() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    builder.append::<Float64Type>("c", 5.0).unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    // "a" is seen first (type id 0), "c" second (type id 1).
    let expected_type_ids = vec![0_i8, 1, 0, 1, 0];
    assert_eq!(*union.type_ids(), expected_type_ids);
    expected_type_ids
        .iter()
        .enumerate()
        .for_each(|(idx, type_id)| assert_eq!(type_id, &union.type_id(idx)));

    assert!(union.offsets().is_none());

    // Even slots were appended as Int32, odd slots as Float64.
    let ints = [1_i32, 4, 6];
    let floats = [3_f64, 5_f64];
    for idx in 0..union.len() {
        let slot = union.value(idx);
        assert!(!union.is_null(idx));
        if idx % 2 == 0 {
            let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
            assert_eq!(slot.len(), 1);
            assert_eq!(ints[idx / 2], slot.value(0));
        } else {
            let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
            assert_eq!(slot.len(), 1);
            assert_eq!(slot.value(0), floats[idx / 2]);
        }
    }
}
1483
#[test]
fn test_sparse_mixed_with_nulls() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let union = builder.build().unwrap();

    // The null appended in slot 1 still carries column "a"'s type id (0).
    let expected_type_ids = vec![0_i8, 0, 1, 0];
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (idx, type_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(type_id, &union.type_id(idx));
    }

    assert!(union.offsets().is_none());

    for idx in 0..union.len() {
        let slot = union.value(idx);
        match idx {
            1 => assert!(slot.is_null(0)),
            0 | 3 => {
                let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
                assert!(!ints.is_null(0));
                assert_eq!(ints.len(), 1);
                let expected = if idx == 0 { 1_i32 } else { 4_i32 };
                assert_eq!(expected, ints.value(0));
            }
            2 => {
                let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
                assert!(!floats.is_null(0));
                assert_eq!(floats.len(), 1);
                assert_eq!(floats.value(0), 3_f64);
            }
            _ => unreachable!(),
        }
    }
}
1533
#[test]
fn test_sparse_mixed_with_nulls_and_offset() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append_null::<Float64Type>("c").unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let union = builder.build().unwrap();

    // Dropping the leading `1` leaves [null, 3.0, null, 4].
    let sliced = union.slice(1, 4);
    let sliced = sliced.as_any().downcast_ref::<UnionArray>().unwrap();
    assert_eq!(4, sliced.len());

    for idx in 0..sliced.len() {
        let slot = sliced.value(idx);
        match idx {
            0 | 2 => assert!(slot.is_null(0)),
            1 => {
                let floats = slot.as_primitive::<Float64Type>();
                assert!(!floats.is_null(0));
                assert_eq!(floats.len(), 1);
                assert_eq!(floats.value(0), 3_f64);
            }
            3 => {
                let ints = slot.as_primitive::<Int32Type>();
                assert!(!ints.is_null(0));
                assert_eq!(ints.len(), 1);
                assert_eq!(4_i32, ints.value(0));
            }
            _ => unreachable!(),
        }
    }
}
1571
1572 fn test_union_validity(union_array: &UnionArray) {
1573 assert_eq!(union_array.null_count(), 0);
1574
1575 for i in 0..union_array.len() {
1576 assert!(!union_array.is_null(i));
1577 assert!(union_array.is_valid(i));
1578 }
1579 }
1580
#[test]
fn test_union_array_validity() {
    // Identical contents in both layouts; each column gets one null entry.
    for mut builder in [UnionBuilder::new_sparse(), UnionBuilder::new_dense()] {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        let union = builder.build().unwrap();

        test_union_validity(&union);
    }
}
1603
#[test]
fn test_type_check() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Float32Type>("a", 1.0).unwrap();

    // Column "a" is now Float32, so appending an Int32 to it must be rejected.
    let expected =
        "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32";
    let message = builder.append::<Int32Type>("a", 1).unwrap_err().to_string();
    assert!(message.contains(expected), "{}", message);
}
1617
#[test]
fn slice_union_array() {
    /// Appends `[a: 1, a: null, c: 3.0, c: null, a: 4]` and builds the union.
    fn create_union(mut builder: UnionBuilder) -> UnionArray {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        builder.build().unwrap()
    }

    /// Wraps `union` as the single nullable column of a record batch.
    fn create_batch(union: UnionArray) -> RecordBatch {
        let field = Field::new("struct_array", union.data_type().clone(), true);
        let schema = Arc::new(Schema::new(vec![field]));
        RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap()
    }

    /// Checks the three rows `[a: null, c: 3.0, c: null]` taken at offset 1.
    fn test_slice_union(record_batch_slice: RecordBatch) {
        let union_slice = record_batch_slice
            .column(0)
            .as_any()
            .downcast_ref::<UnionArray>()
            .unwrap();

        // Row 0 comes from "a" (type id 0); rows 1 and 2 come from "c" (type id 1).
        let type_ids: Vec<i8> = (0..3).map(|idx| union_slice.type_id(idx)).collect();
        assert_eq!(type_ids, [0, 1, 1]);

        let first = union_slice.value(0);
        let first = first.as_primitive::<Int32Type>();
        assert_eq!(first.len(), 1);
        assert!(first.is_null(0));

        let second = union_slice.value(1);
        let second = second.as_primitive::<Float64Type>();
        assert_eq!(second.len(), 1);
        assert!(second.is_valid(0));
        assert_eq!(second.value(0), 3.0);

        let third = union_slice.value(2);
        let third = third.as_primitive::<Float64Type>();
        assert_eq!(third.len(), 1);
        assert!(third.is_null(0));
    }

    for builder in [UnionBuilder::new_sparse(), UnionBuilder::new_dense()] {
        let record_batch = create_batch(create_union(builder));
        test_slice_union(record_batch.slice(1, 3));
    }
}
1682
#[test]
fn test_custom_type_ids() {
    // Field type ids are non-contiguous and unordered: 8 -> strings,
    // 4 -> integers, 9 -> floats.
    let data_type = DataType::Union(
        UnionFields::try_new(
            vec![8, 4, 9],
            vec![
                Field::new("strings", DataType::Utf8, false),
                Field::new("integers", DataType::Int32, false),
                Field::new("floats", DataType::Float64, false),
            ],
        )
        .unwrap(),
        UnionMode::Dense,
    );

    let children = vec![
        StringArray::from(vec!["foo", "bar", "baz"]).into_data(),
        Int32Array::from(vec![5, 6, 4]).into_data(),
        Float64Array::from(vec![10.0]).into_data(),
    ];
    // Each (type id, offset) pair selects one element of the matching child.
    let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
    let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);

    let data = ArrayData::builder(data_type)
        .len(7)
        .buffers(vec![type_ids, value_offsets])
        .child_data(children)
        .build()
        .unwrap();
    let array = UnionArray::from(data);

    // Fetches slot `idx`, asserting it is one element of type `expected`.
    let slot = |idx: usize, expected: DataType| {
        let value = array.value(idx);
        assert_eq!(value.data_type(), &expected);
        assert_eq!(value.len(), 1);
        value
    };

    assert_eq!(slot(0, DataType::Int32).as_primitive::<Int32Type>().value(0), 5);
    assert_eq!(slot(1, DataType::Utf8).as_string::<i32>().value(0), "foo");
    assert_eq!(slot(2, DataType::Int32).as_primitive::<Int32Type>().value(0), 6);
    assert_eq!(slot(3, DataType::Utf8).as_string::<i32>().value(0), "bar");
    assert_eq!(
        slot(4, DataType::Float64)
            .as_primitive::<Float64Type>()
            .value(0),
        10.0
    );
    assert_eq!(slot(5, DataType::Int32).as_primitive::<Int32Type>().value(0), 4);
    assert_eq!(slot(6, DataType::Utf8).as_string::<i32>().value(0), "baz");
}
1753
#[test]
fn into_parts() {
    /// Appends `[a: 1, b: 2, a: 3]` and builds the union.
    fn build(mut builder: UnionBuilder) -> UnionArray {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append::<Int8Type>("b", 2).unwrap();
        builder.append::<Int32Type>("a", 3).unwrap();
        builder.build().unwrap()
    }

    // Dense: the parts include an offsets buffer.
    let expected_fields = [
        Arc::new(Field::new("a", DataType::Int32, false)),
        Arc::new(Field::new("b", DataType::Int8, false)),
    ];
    let (union_fields, type_ids, offsets, children) =
        build(UnionBuilder::new_dense()).into_parts();
    assert_eq!(
        union_fields
            .iter()
            .map(|(_, field)| field)
            .collect::<Vec<_>>(),
        expected_fields.iter().collect::<Vec<_>>()
    );
    assert_eq!(type_ids, [0, 1, 0]);
    assert!(offsets.is_some());
    assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);

    let rebuilt = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt.is_ok());
    assert_eq!(rebuilt.unwrap().len(), 3);

    // Sparse: same contents, but no offsets buffer.
    let (union_fields, type_ids, offsets, children) =
        build(UnionBuilder::new_sparse()).into_parts();
    assert_eq!(type_ids, [0, 1, 0]);
    assert!(offsets.is_none());

    let rebuilt = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt.is_ok());
    assert_eq!(rebuilt.unwrap().len(), 3);
}
1796
/// Round-trips a dense union with non-contiguous custom type ids through
/// `into_parts` / `try_new`. Checks that the field ↔ type-id pairing, the
/// type-id and offset buffers, and the child mapping are all preserved, not
/// just the set of type ids.
#[test]
fn into_parts_custom_type_ids() {
    // Non-contiguous, unordered ids: 8 -> strings, 4 -> integers, 9 -> floats.
    let set_field_type_ids: [i8; 3] = [8, 4, 9];
    let data_type = DataType::Union(
        UnionFields::try_new(
            set_field_type_ids,
            [
                Field::new("strings", DataType::Utf8, false),
                Field::new("integers", DataType::Int32, false),
                Field::new("floats", DataType::Float64, false),
            ],
        )
        .unwrap(),
        UnionMode::Dense,
    );
    let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
    let int_array = Int32Array::from(vec![5, 6, 4]);
    let float_array = Float64Array::from(vec![10.0]);
    let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
    let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
    let data = ArrayData::builder(data_type)
        .len(7)
        .buffers(vec![type_ids, value_offsets])
        .child_data(vec![
            string_array.into_data(),
            int_array.into_data(),
            float_array.into_data(),
        ])
        .build()
        .unwrap();
    let array = UnionArray::from(data);

    let (union_fields, type_ids, offsets, children) = array.into_parts();

    // Each custom type id must still be paired with the same field.
    assert_eq!(
        union_fields
            .iter()
            .map(|(type_id, field)| (type_id, field.name().as_str()))
            .collect::<HashSet<_>>(),
        HashSet::from([(8, "strings"), (4, "integers"), (9, "floats")])
    );
    // Every id in the buffer refers to a declared field...
    assert_eq!(
        type_ids.iter().collect::<HashSet<_>>(),
        set_field_type_ids.iter().collect::<HashSet<_>>()
    );
    // ...and the buffers come back unchanged, in slot order.
    assert_eq!(type_ids, [4, 8, 4, 8, 9, 4, 8]);
    assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1, 1, 0, 2, 2]);
    assert_eq!(children.len(), 3);

    let result = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(result.is_ok());
    let array = result.unwrap();
    assert_eq!(array.len(), 7);
    assert_eq!(*array.type_ids(), [4, 8, 4, 8, 9, 4, 8]);
    // Slot 4 has type id 9 at offset 0, so it must resolve to the floats child.
    assert_eq!(array.value(4).as_primitive::<Float64Type>().value(0), 10.0);
}
1839
#[test]
fn test_invalid() {
    // Two Utf8 fields with type ids 3 ("a") and 2 ("b").
    let fields = UnionFields::try_new(
        [3, 2],
        [
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ],
    )
    .unwrap();
    let children = vec![
        Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
        Arc::new(StringArray::from_iter_values(["c", "d"])) as _,
    ];

    // Sparse: three type ids, but each child has only two rows.
    let message =
        UnionArray::try_new(fields.clone(), vec![3, 3, 2].into(), None, children.clone())
            .unwrap_err()
            .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Sparse union child arrays must be equal in length to the length of the union"
    );

    // Sparse: type id 1 matches no declared field.
    let message = UnionArray::try_new(fields.clone(), vec![1, 2].into(), None, children.clone())
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Type Ids values must match one of the field type ids"
    );

    // Sparse: type id 7 matches no declared field either.
    let message = UnionArray::try_new(fields.clone(), vec![7, 2].into(), None, children)
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Type Ids values must match one of the field type ids"
    );

    // Dense: children may differ in length; offsets index into them.
    let children = vec![
        Arc::new(StringArray::from_iter_values(["a", "b"])) as _,
        Arc::new(StringArray::from_iter_values(["c"])) as _,
    ];
    let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]);
    UnionArray::try_new(
        fields.clone(),
        type_ids.clone(),
        Some(vec![0, 1, 0].into()),
        children.clone(),
    )
    .unwrap();

    // Dense: slot 2 (type id 2) points at offset 1 of a one-element child.
    let message = UnionArray::try_new(
        fields.clone(),
        type_ids.clone(),
        Some(vec![0, 1, 1].into()),
        children.clone(),
    )
    .unwrap_err()
    .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Offsets must be non-negative and within the length of the Array"
    );

    // Dense: two offsets for three type ids.
    let message = UnionArray::try_new(
        fields.clone(),
        type_ids.clone(),
        Some(vec![0, 1].into()),
        children,
    )
    .unwrap_err()
    .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Type Ids and Offsets lengths must match"
    );

    // Two declared fields but no children at all.
    let message = UnionArray::try_new(fields, type_ids, None, vec![])
        .unwrap_err()
        .to_string();
    assert_eq!(
        message,
        "Invalid argument error: Union fields length must match child arrays length"
    );
}
1911
/// Cases where `logical_nulls` is expected to give a direct answer:
/// `None` when no slot can be null, or an all-null buffer when every child
/// is entirely null.
#[test]
fn test_logical_nulls_fast_paths() {
    // No fields at all.
    let array = UnionArray::try_new(UnionFields::empty(), vec![].into(), None, vec![]).unwrap();

    assert_eq!(array.logical_nulls(), None);

    let fields = UnionFields::try_new(
        [1, 3],
        [
            Field::new("a", DataType::Int8, false), // non-nullable
            Field::new("b", DataType::Int8, false), // non-nullable
        ],
    )
    .unwrap();
    let array = UnionArray::try_new(
        fields,
        vec![1].into(),
        None,
        vec![
            Arc::new(Int8Array::from_value(5, 1)),
            Arc::new(Int8Array::from_value(5, 1)),
        ],
    )
    .unwrap();

    assert_eq!(array.logical_nulls(), None);

    let nullable_fields = UnionFields::try_new(
        [1, 3],
        [
            Field::new("a", DataType::Int8, true), // nullable
            Field::new("b", DataType::Int8, true), // nullable
        ],
    )
    .unwrap();
    let array = UnionArray::try_new(
        nullable_fields.clone(),
        vec![1, 1].into(),
        None,
        vec![
            Arc::new(Int8Array::from_value(-5, 2)), // nullable, but holds no nulls
            Arc::new(Int8Array::from_value(-5, 2)), // nullable, but holds no nulls
        ],
    )
    .unwrap();

    assert_eq!(array.logical_nulls(), None);

    let array = UnionArray::try_new(
        nullable_fields.clone(),
        vec![1, 1].into(),
        None,
        vec![
            // every child is completely null
            Arc::new(Int8Array::new_null(2)), // same length as the union
            Arc::new(Int8Array::new_null(2)), // same length as the union
        ],
    )
    .unwrap();

    assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));

    let array = UnionArray::try_new(
        nullable_fields,
        vec![1, 1].into(),
        Some(vec![0, 1].into()),
        vec![
            // every child is completely null
            Arc::new(Int8Array::new_null(3)), // longer than the union
            Arc::new(Int8Array::new_null(3)), // longer than the union
        ],
    )
    .unwrap();

    assert_eq!(array.logical_nulls(), Some(NullBuffer::new_null(2)));
}
1989
#[test]
fn test_dense_union_logical_nulls_gather() {
    // `union_fields()` maps type id 1 -> A (Int32), 3 -> B (Float64), 4 -> C (Utf8).
    let ints = Int32Array::from(vec![1, 2]);
    let floats = Float64Array::from(vec![Some(3.2), None]);
    let strings = StringArray::new_null(1);

    let type_ids = ScalarBuffer::<i8>::from(vec![1, 1, 3, 3, 4, 4]);
    let offsets = ScalarBuffer::<i32>::from(vec![0, 1, 0, 1, 0, 0]);
    let children = vec![
        Arc::new(ints) as Arc<dyn Array>,
        Arc::new(floats),
        Arc::new(strings),
    ];

    let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();

    // ints[0], ints[1] and floats[0] are valid; floats[1] and strings[0]
    // (referenced twice) are null.
    let expected = BooleanBuffer::from(vec![true, true, true, false, false, false]);

    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
}
2014
#[test]
fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() {
    /// Checks both `logical_nulls` and the specialised mask against `expected`.
    fn check(array: &UnionArray, expected: &BooleanBuffer) {
        assert_eq!(*expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            *expected,
            array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
        );
    }

    let fields: UnionFields = [
        (1, Arc::new(Field::new("A", DataType::Int32, true))),
        (3, Arc::new(Field::new("B", DataType::Float64, true))),
    ]
    .into_iter()
    .collect();

    // Child A is entirely null; B has a single valid value, at index 2.
    let children = vec![
        Arc::new(Int32Array::new_null(4)) as Arc<dyn Array>,
        Arc::new(Float64Array::from(vec![None, None, Some(3.2), None])),
    ];
    let type_ids = ScalarBuffer::<i8>::from(vec![1, 1, 3, 3]);
    let array = UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap();
    check(&array, &BooleanBuffer::from(vec![false, false, true, false]));

    // Two full 64-element chunks plus a 32-element remainder.
    let len = 2 * 64 + 32;
    let floats = Float64Array::from_iter([Some(3.2), None].into_iter().cycle().take(len));
    let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3].into_iter().cycle().take(len));
    let array = UnionArray::try_new(
        fields,
        type_ids,
        None,
        vec![Arc::new(Int32Array::new_null(len)), Arc::new(floats)],
    )
    .unwrap();

    assert_eq!(array.len(), len);
    check(
        &array,
        &BooleanBuffer::from_iter([false, false, true, false].into_iter().cycle().take(len)),
    );
}
2066
#[test]
fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_valid() {
    /// Checks both `logical_nulls` and the specialised mask against `expected`.
    fn check(array: &UnionArray, expected: &BooleanBuffer) {
        assert_eq!(*expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            *expected,
            array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
        );
    }

    // A (id 1) and B (id 3) hold no nulls; C (id 4) is entirely null.
    let children = vec![
        Arc::new(Int32Array::from_value(2, 6)) as Arc<dyn Array>,
        Arc::new(Float64Array::from_value(4.2, 6)),
        Arc::new(StringArray::new_null(6)),
    ];
    let type_ids = ScalarBuffer::<i8>::from(vec![1, 1, 3, 3, 4, 4]);
    let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();
    check(
        &array,
        &BooleanBuffer::from(vec![true, true, true, true, false, false]),
    );

    // Same pattern over two 64-element chunks plus a 32-element remainder,
    // with C now alternating null / "a".
    let len = 2 * 64 + 32;
    let children = vec![
        Arc::new(Int32Array::from_value(2, len)) as Arc<dyn Array>,
        Arc::new(Float64Array::from_value(4.2, len)),
        Arc::new(StringArray::from_iter(
            [None, Some("a")].into_iter().cycle().take(len),
        )),
    ];
    let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
    let array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap();

    assert_eq!(array.len(), len);
    check(
        &array,
        &BooleanBuffer::from_iter(
            [true, true, true, true, false, true]
                .into_iter()
                .cycle()
                .take(len),
        ),
    );
}
2121
#[test]
fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_null() {
    /// Builds a `len`-row sparse union where A (id 1) and C (id 4) are
    /// entirely null and B (id 3) is always 4.2, with type ids cycling
    /// `[1, 1, 3, 3, 4, 4]`.
    fn build(len: usize) -> UnionArray {
        let children = vec![
            Arc::new(Int32Array::new_null(len)) as Arc<dyn Array>,
            Arc::new(Float64Array::from_value(4.2, len)),
            Arc::new(StringArray::new_null(len)),
        ];
        let type_ids = ScalarBuffer::from_iter([1, 1, 3, 3, 4, 4].into_iter().cycle().take(len));
        UnionArray::try_new(union_fields(), type_ids, None, children).unwrap()
    }

    /// Checks both `logical_nulls` and the specialised mask against `expected`.
    fn check(array: &UnionArray, expected: &BooleanBuffer) {
        assert_eq!(*expected, array.logical_nulls().unwrap().into_inner());
        assert_eq!(
            *expected,
            array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
        );
    }

    // Only the B slots are valid.
    let pattern = [false, false, true, true, false, false];

    let array = build(6);
    check(&array, &BooleanBuffer::from(pattern.to_vec()));

    // Two full 64-element chunks plus a 32-element remainder.
    let len = 2 * 64 + 32;
    let array = build(len);

    assert_eq!(array.len(), len);
    check(
        &array,
        &BooleanBuffer::from_iter(pattern.into_iter().cycle().take(len)),
    );
}
2176
#[test]
fn test_sparse_union_logical_nulls_gather() {
    // 50 nullable Int32 fields with odd type ids 1, 3, 5, ...
    let n_fields = 50;
    let fields: UnionFields = (1..)
        .step_by(2)
        .take(n_fields)
        .map(|i| {
            (
                i,
                Arc::new(Field::new(format!("f{i}"), DataType::Int32, true)),
            )
        })
        .collect();

    // Children cycle through non-null, mixed and fully-null arrays.
    let pattern = [
        Arc::new(Int32Array::from_value(2, 4)) as ArrayRef,
        Arc::new(Int32Array::from(vec![None, None, Some(1), None])),
        Arc::new(Int32Array::new_null(4)),
    ];
    let children: Vec<ArrayRef> = pattern.into_iter().cycle().take(n_fields).collect();

    let array = UnionArray::try_new(fields, vec![1, 3, 3, 5].into(), None, children).unwrap();

    // Slot 0 -> f1 (non-null); 1 -> f3 mixed[1] (null); 2 -> f3 mixed[2] (valid);
    // 3 -> f5 (fully null).
    let expected = BooleanBuffer::from(vec![true, false, true, false]);

    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
}
2215
2216 fn union_fields() -> UnionFields {
2217 [
2218 (1, Arc::new(Field::new("A", DataType::Int32, true))),
2219 (3, Arc::new(Field::new("B", DataType::Float64, true))),
2220 (4, Arc::new(Field::new("C", DataType::Utf8, true))),
2221 ]
2222 .into_iter()
2223 .collect()
2224 }
2225
#[test]
fn test_is_nullable() {
    // (int_nullable, float_nullable, expected `is_nullable()`)
    let cases = [
        (false, false, false),
        (true, false, true),
        (false, true, true),
        (true, true, true),
    ];
    for (int_nullable, float_nullable, expected) in cases {
        assert_eq!(
            create_union_array(int_nullable, float_nullable).is_nullable(),
            expected
        );
    }
}
2233
2234 fn create_union_array(int_nullable: bool, float_nullable: bool) -> UnionArray {
2241 let int_array = if int_nullable {
2242 Int32Array::from(vec![Some(1), None, Some(3)])
2243 } else {
2244 Int32Array::from(vec![1, 2, 3])
2245 };
2246 let float_array = if float_nullable {
2247 Float64Array::from(vec![Some(3.2), None, Some(4.2)])
2248 } else {
2249 Float64Array::from(vec![3.2, 4.2, 5.2])
2250 };
2251 let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
2252 let offsets = [0, 0, 0].into_iter().collect::<ScalarBuffer<i32>>();
2253 let union_fields = [
2254 (0, Arc::new(Field::new("A", DataType::Int32, true))),
2255 (1, Arc::new(Field::new("B", DataType::Float64, true))),
2256 ]
2257 .into_iter()
2258 .collect::<UnionFields>();
2259
2260 let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2261
2262 UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap()
2263 }
2264}