1#![allow(clippy::enum_clike_unportable_variant)]
18
19use crate::{make_array, Array, ArrayRef};
20use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks};
21use arrow_buffer::buffer::NullBuffer;
22use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer};
23use arrow_data::{ArrayData, ArrayDataBuilder};
24use arrow_schema::{ArrowError, DataType, UnionFields, UnionMode};
25use std::any::Any;
28use std::collections::HashSet;
29use std::sync::Arc;
30
/// An array whose slots each hold a value taken from one of several child arrays.
///
/// Every slot has a type id selecting the child that stores its value. When
/// `offsets` is present (dense mode) the slot also carries an index into that
/// child; when it is absent (sparse mode) the slot position is used as the index.
#[derive(Clone)]
pub struct UnionArray {
    /// The `DataType::Union` carrying the union fields and mode.
    data_type: DataType,
    /// One type id per slot, selecting the child array for that slot.
    type_ids: ScalarBuffer<i8>,
    /// One child index per slot; `Some` for dense unions, `None` for sparse ones.
    offsets: Option<ScalarBuffer<i32>>,
    /// Child arrays indexed by type id; ids without a field are `None`.
    fields: Vec<Option<ArrayRef>>,
}
129
130impl UnionArray {
131 pub unsafe fn new_unchecked(
150 fields: UnionFields,
151 type_ids: ScalarBuffer<i8>,
152 offsets: Option<ScalarBuffer<i32>>,
153 children: Vec<ArrayRef>,
154 ) -> Self {
155 let mode = if offsets.is_some() {
156 UnionMode::Dense
157 } else {
158 UnionMode::Sparse
159 };
160
161 let len = type_ids.len();
162 let builder = ArrayData::builder(DataType::Union(fields, mode))
163 .add_buffer(type_ids.into_inner())
164 .child_data(children.into_iter().map(Array::into_data).collect())
165 .len(len);
166
167 let data = match offsets {
168 Some(offsets) => builder.add_buffer(offsets.into_inner()).build_unchecked(),
169 None => builder.build_unchecked(),
170 };
171 Self::from(data)
172 }
173
174 pub fn try_new(
178 fields: UnionFields,
179 type_ids: ScalarBuffer<i8>,
180 offsets: Option<ScalarBuffer<i32>>,
181 children: Vec<ArrayRef>,
182 ) -> Result<Self, ArrowError> {
183 if fields.len() != children.len() {
185 return Err(ArrowError::InvalidArgumentError(
186 "Union fields length must match child arrays length".to_string(),
187 ));
188 }
189
190 if let Some(offsets) = &offsets {
191 if offsets.len() != type_ids.len() {
193 return Err(ArrowError::InvalidArgumentError(
194 "Type Ids and Offsets lengths must match".to_string(),
195 ));
196 }
197 } else {
198 for child in &children {
200 if child.len() != type_ids.len() {
201 return Err(ArrowError::InvalidArgumentError(
202 "Sparse union child arrays must be equal in length to the length of the union".to_string(),
203 ));
204 }
205 }
206 }
207
208 let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
210 let mut array_lens = vec![i32::MIN; max_id + 1];
211 for (cd, (field_id, _)) in children.iter().zip(fields.iter()) {
212 array_lens[field_id as usize] = cd.len() as i32;
213 }
214
215 for id in &type_ids {
217 match array_lens.get(*id as usize) {
218 Some(x) if *x != i32::MIN => {}
219 _ => {
220 return Err(ArrowError::InvalidArgumentError(
221 "Type Ids values must match one of the field type ids".to_owned(),
222 ))
223 }
224 }
225 }
226
227 if let Some(offsets) = &offsets {
229 let mut iter = type_ids.iter().zip(offsets.iter());
230 if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize])
231 {
232 return Err(ArrowError::InvalidArgumentError(
233 "Offsets must be positive and within the length of the Array".to_owned(),
234 ));
235 }
236 }
237
238 let union_array = unsafe { Self::new_unchecked(fields, type_ids, offsets, children) };
241 Ok(union_array)
242 }
243
244 pub fn child(&self, type_id: i8) -> &ArrayRef {
251 assert!((type_id as usize) < self.fields.len());
252 let boxed = &self.fields[type_id as usize];
253 boxed.as_ref().expect("invalid type id")
254 }
255
    /// Returns the type id stored for the slot at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not less than the number of type ids.
    pub fn type_id(&self, index: usize) -> i8 {
        assert!(index < self.type_ids.len());
        self.type_ids[index]
    }
265
    /// Returns the buffer holding one type id per slot.
    pub fn type_ids(&self) -> &ScalarBuffer<i8> {
        &self.type_ids
    }
270
    /// Returns the per-slot child offsets, or `None` when this union stores no offsets buffer.
    pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> {
        self.offsets.as_ref()
    }
275
276 pub fn value_offset(&self, index: usize) -> usize {
282 assert!(index < self.len());
283 match &self.offsets {
284 Some(offsets) => offsets[index] as usize,
285 None => self.offset() + index,
286 }
287 }
288
289 pub fn value(&self, i: usize) -> ArrayRef {
297 let type_id = self.type_id(i);
298 let value_offset = self.value_offset(i);
299 let child = self.child(type_id);
300 child.slice(value_offset, 1)
301 }
302
303 pub fn type_names(&self) -> Vec<&str> {
305 match self.data_type() {
306 DataType::Union(fields, _) => fields
307 .iter()
308 .map(|(_, f)| f.name().as_str())
309 .collect::<Vec<&str>>(),
310 _ => unreachable!("Union array's data type is not a union!"),
311 }
312 }
313
314 fn is_dense(&self) -> bool {
316 match self.data_type() {
317 DataType::Union(_, mode) => mode == &UnionMode::Dense,
318 _ => unreachable!("Union array's data type is not a union!"),
319 }
320 }
321
322 pub fn slice(&self, offset: usize, length: usize) -> Self {
324 let (offsets, fields) = match self.offsets.as_ref() {
325 Some(offsets) => (Some(offsets.slice(offset, length)), self.fields.clone()),
327 None => {
329 let fields = self
330 .fields
331 .iter()
332 .map(|x| x.as_ref().map(|x| x.slice(offset, length)))
333 .collect();
334 (None, fields)
335 }
336 };
337
338 Self {
339 data_type: self.data_type.clone(),
340 type_ids: self.type_ids.slice(offset, length),
341 offsets,
342 fields,
343 }
344 }
345
346 #[allow(clippy::type_complexity)]
374 pub fn into_parts(
375 self,
376 ) -> (
377 UnionFields,
378 ScalarBuffer<i8>,
379 Option<ScalarBuffer<i32>>,
380 Vec<ArrayRef>,
381 ) {
382 let Self {
383 data_type,
384 type_ids,
385 offsets,
386 mut fields,
387 } = self;
388 match data_type {
389 DataType::Union(union_fields, _) => {
390 let children = union_fields
391 .iter()
392 .map(|(type_id, _)| fields[type_id as usize].take().unwrap())
393 .collect();
394 (union_fields, type_ids, offsets, children)
395 }
396 _ => unreachable!(),
397 }
398 }
399
    /// Builds the validity mask of a sparse union by masking only the fields listed in `nulls`.
    ///
    /// `nulls` holds `(type id, validity)` pairs. A slot whose type id does not
    /// appear in `nulls` is marked valid; otherwise the slot takes the validity
    /// bit of its selected field. Set bits mean "valid".
    fn mask_sparse_skip_without_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
        // Accumulates, over the fields in `nulls`:
        // - which slots select any of those fields, and
        // - the validity bits of the selected field for those slots.
        let fold = |(with_nulls_selected, union_nulls), (is_field, field_nulls)| {
            (
                with_nulls_selected | is_field,
                union_nulls | (is_field & field_nulls),
            )
        };

        self.mask_sparse_helper(
            nulls,
            // Full chunk of 64 type ids: pull the next 64 validity bits of every field.
            |type_ids_chunk_array, nulls_masks_iters| {
                let (with_nulls_selected, union_nulls) = nulls_masks_iters
                    .iter_mut()
                    .map(|(field_type_id, field_nulls)| {
                        let field_nulls = field_nulls.next().unwrap();
                        let is_field = selection_mask(type_ids_chunk_array, *field_type_id);

                        (is_field, field_nulls)
                    })
                    .fold((0, 0), fold);

                // Slots that selected none of the fields in `nulls`.
                let without_nulls_selected = !with_nulls_selected;

                // Those slots are always valid; the rest use the bits gathered above.
                without_nulls_selected | union_nulls
            },
            // Trailing slots (fewer than 64): same logic on the remainder bits.
            |type_ids_remainder, bit_chunks| {
                let (with_nulls_selected, union_nulls) = bit_chunks
                    .iter()
                    .map(|(field_type_id, field_bit_chunks)| {
                        let field_nulls = field_bit_chunks.remainder_bits();
                        let is_field = selection_mask(type_ids_remainder, *field_type_id);

                        (is_field, field_nulls)
                    })
                    .fold((0, 0), fold);

                let without_nulls_selected = !with_nulls_selected;

                without_nulls_selected | union_nulls
            },
        )
    }
451
    /// Builds the validity mask of a sparse union while skipping fields that are entirely null.
    ///
    /// `nulls` holds `(type id, validity)` pairs. Entries whose validity is all
    /// null are dropped before masking, so slots selecting them contribute no
    /// set bit and stay null. Slots selecting a field that is absent from
    /// `nulls` are marked valid explicitly. Set bits mean "valid".
    fn mask_sparse_skip_fully_null(&self, mut nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
        let fields = match self.data_type() {
            DataType::Union(fields, _) => fields,
            _ => unreachable!("Union array's data type is not a union!"),
        };

        // Type ids of the fields that are NOT listed in `nulls`.
        let type_ids = fields.iter().map(|(id, _)| id).collect::<HashSet<_>>();
        let with_nulls = nulls.iter().map(|(id, _)| *id).collect::<HashSet<_>>();

        let without_nulls_ids = type_ids
            .difference(&with_nulls)
            .copied()
            .collect::<Vec<_>>();

        // Keep only the fields that have at least one valid value.
        nulls.retain(|(_, nulls)| nulls.null_count() < nulls.len());

        self.mask_sparse_helper(
            nulls,
            // Full chunk of 64 type ids.
            |type_ids_chunk_array, nulls_masks_iters| {
                let union_nulls = nulls_masks_iters.iter_mut().fold(
                    0,
                    |union_nulls, (field_type_id, nulls_iter)| {
                        let field_nulls = nulls_iter.next().unwrap();

                        // A chunk with no valid bit cannot contribute; skip the selection mask.
                        if field_nulls == 0 {
                            union_nulls
                        } else {
                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);

                            union_nulls | (is_field & field_nulls)
                        }
                    },
                );

                // Fully null fields were removed above, so the slots of fields
                // absent from `nulls` have to be selected by id rather than by negation.
                let without_nulls_selected =
                    without_nulls_selected(type_ids_chunk_array, &without_nulls_ids);

                union_nulls | without_nulls_selected
            },
            // Trailing slots (fewer than 64).
            |type_ids_remainder, bit_chunks| {
                let union_nulls =
                    bit_chunks
                        .iter()
                        .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
                            let is_field = selection_mask(type_ids_remainder, *field_type_id);
                            let field_nulls = field_bit_chunks.remainder_bits();

                            union_nulls | is_field & field_nulls
                        });

                union_nulls | without_nulls_selected(type_ids_remainder, &without_nulls_ids)
            },
        )
    }
514
    /// Builds the validity mask of a sparse union when every field is listed in `nulls`.
    ///
    /// For full chunks the selection mask of the first entry is not computed
    /// directly: it is derived as the complement of the slots selecting any of
    /// the other entries. NOTE(review): this is only equivalent when every type
    /// id in the array belongs to an entry of `nulls` — the caller in
    /// `logical_nulls` only enables this strategy under that condition.
    ///
    /// # Panics
    ///
    /// Panics if `nulls` is empty and the array has at least 64 slots, because
    /// the full-chunk closure indexes the first entry.
    fn mask_sparse_all_with_nulls_skip_one(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
        self.mask_sparse_helper(
            nulls,
            // Full chunk of 64 type ids.
            |type_ids_chunk_array, nulls_masks_iters| {
                // Handle every entry except the first.
                let (is_not_first, union_nulls) = nulls_masks_iters[1..]
                    .iter_mut()
                    .fold(
                        (0, 0),
                        |(is_not_first, union_nulls), (field_type_id, nulls_iter)| {
                            let field_nulls = nulls_iter.next().unwrap();
                            let is_field = selection_mask(type_ids_chunk_array, *field_type_id);

                            (
                                is_not_first | is_field,
                                union_nulls | (is_field & field_nulls),
                            )
                        },
                    );

                // Whatever did not select another entry must have selected the first one.
                let is_first = !is_not_first;
                let first_nulls = nulls_masks_iters[0].1.next().unwrap();

                (is_first & first_nulls) | union_nulls
            },
            // Trailing slots (fewer than 64): this runs at most once, so every
            // entry, including the first, gets its own selection mask.
            |type_ids_remainder, bit_chunks| {
                bit_chunks
                    .iter()
                    .fold(0, |union_nulls, (field_type_id, field_bit_chunks)| {
                        let field_nulls = field_bit_chunks.remainder_bits();
                        let is_field = selection_mask(type_ids_remainder, *field_type_id);

                        union_nulls | (is_field & field_nulls)
                    })
            },
        )
    }
560
    /// Drives the chunk-wise computation shared by the sparse masking strategies.
    ///
    /// The type ids are processed 64 at a time. `mask_chunk` receives one full
    /// chunk of type ids plus one bit-chunk iterator per entry of `nulls` and
    /// returns the 64 validity bits for that chunk. `mask_remainder` is called
    /// at most once, for the trailing type ids that do not fill a chunk, and
    /// receives the `BitChunks` of every entry so it can read the remainder bits.
    /// The result has one bit per type id.
    fn mask_sparse_helper(
        &self,
        nulls: Vec<(i8, NullBuffer)>,
        mut mask_chunk: impl FnMut(&[i8; 64], &mut [(i8, BitChunkIterator)]) -> u64,
        mask_remainder: impl FnOnce(&[i8], &[(i8, BitChunks)]) -> u64,
    ) -> BooleanBuffer {
        let bit_chunks = nulls
            .iter()
            .map(|(type_id, nulls)| (*type_id, nulls.inner().bit_chunks()))
            .collect::<Vec<_>>();

        // One iterator of 64-bit validity words per entry, advanced in lock-step by `mask_chunk`.
        let mut nulls_masks_iter = bit_chunks
            .iter()
            .map(|(type_id, bit_chunks)| (*type_id, bit_chunks.iter()))
            .collect::<Vec<_>>();

        let chunks_exact = self.type_ids.chunks_exact(64);
        let remainder = chunks_exact.remainder();

        let chunks = chunks_exact.map(|type_ids_chunk| {
            // `chunks_exact(64)` only yields slices of length 64, so this conversion cannot fail.
            let type_ids_chunk_array = <&[i8; 64]>::try_from(type_ids_chunk).unwrap();

            mask_chunk(type_ids_chunk_array, &mut nulls_masks_iter)
        });

        // SAFETY: `chunks` is a `map` over `ChunksExact`, which reports an exact length.
        let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) };

        // One extra 64-bit word carries the bits of the trailing slots.
        if !remainder.is_empty() {
            buffer.push(mask_remainder(remainder, &bit_chunks));
        }

        BooleanBuffer::new(buffer.into(), 0, self.type_ids.len())
    }
598
    /// Computes the validity of every slot by looking up the selected field's validity bit.
    ///
    /// `nulls` holds `(type id, validity)` pairs. Slots whose type id is not
    /// listed are valid. For a dense union the bit is read at the slot's
    /// offset, for a sparse union at the slot's position.
    fn gather_nulls(&self, nulls: Vec<(i8, NullBuffer)>) -> BooleanBuffer {
        let one_null = NullBuffer::new_null(1);
        let one_valid = NullBuffer::new_valid(1);

        // The unsafe code below depends on this table.
        // Each type id maps to a validity buffer plus an index mask. Ids that
        // are unlisted or fully null use a 1-element buffer with `Mask::Zero`,
        // so any index and-ed with the mask becomes 0. Other ids use their real
        // buffer with `Mask::Max`, which leaves the index unchanged. This
        // avoids a branch per slot. The table has 256 entries so that
        // `type_id as u8 as usize` is always in bounds.
        let mut logical_nulls_array = [(&one_valid, Mask::Zero); 256];

        for (type_id, nulls) in &nulls {
            if nulls.null_count() == nulls.len() {
                // Fully null: a single null bit answers every lookup.
                logical_nulls_array[*type_id as u8 as usize] = (&one_null, Mask::Zero);
            } else {
                logical_nulls_array[*type_id as u8 as usize] = (nulls, Mask::Max);
            }
        }

        match &self.offsets {
            Some(offsets) => {
                assert_eq!(self.type_ids.len(), offsets.len());

                // SAFETY (NOTE(review): relies on `collect_bool` only passing indices below
                // the given length — confirm):
                // - `i` is then in bounds of `type_ids`, and of `offsets` by the assert above.
                // - The masked index is either 0 into a 1-element buffer, or the stored
                //   offset into the field's real buffer; the latter is presumably in
                //   bounds because offsets were validated at construction.
                BooleanBuffer::collect_bool(self.type_ids.len(), |i| unsafe {
                    let type_id = *self.type_ids.get_unchecked(i);
                    let offset = *offsets.get_unchecked(i);

                    let (nulls, offset_mask) = &logical_nulls_array[type_id as u8 as usize];

                    nulls
                        .inner()
                        .value_unchecked(offset as usize & *offset_mask as usize)
                })
            }
            None => {
                // SAFETY (NOTE(review): same assumption about `collect_bool` as above):
                // - `index` is in bounds of `type_ids`.
                // - The masked index is either 0 into a 1-element buffer, or `index`
                //   into the field's real buffer, presumably as long as the union
                //   because sparse children match the union's length.
                BooleanBuffer::collect_bool(self.type_ids.len(), |index| unsafe {
                    let type_id = *self.type_ids.get_unchecked(index);

                    let (nulls, index_mask) = &logical_nulls_array[type_id as u8 as usize];

                    nulls.inner().value_unchecked(index & *index_mask as usize)
                })
            }
        }
    }
660
661 fn fields_logical_nulls(&self) -> Vec<(i8, NullBuffer)> {
664 self.fields
665 .iter()
666 .enumerate()
667 .filter_map(|(type_id, field)| Some((type_id as i8, field.as_ref()?.logical_nulls()?)))
668 .filter(|(_, nulls)| nulls.null_count() > 0)
669 .collect()
670 }
671}
672
673impl From<ArrayData> for UnionArray {
674 fn from(data: ArrayData) -> Self {
675 let (fields, mode) = match data.data_type() {
676 DataType::Union(fields, mode) => (fields, *mode),
677 d => panic!("UnionArray expected ArrayData with type Union got {d}"),
678 };
679 let (type_ids, offsets) = match mode {
680 UnionMode::Sparse => (
681 ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
682 None,
683 ),
684 UnionMode::Dense => (
685 ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
686 Some(ScalarBuffer::new(
687 data.buffers()[1].clone(),
688 data.offset(),
689 data.len(),
690 )),
691 ),
692 };
693
694 let max_id = fields.iter().map(|(i, _)| i).max().unwrap_or_default() as usize;
695 let mut boxed_fields = vec![None; max_id + 1];
696 for (cd, (field_id, _)) in data.child_data().iter().zip(fields.iter()) {
697 boxed_fields[field_id as usize] = Some(make_array(cd.clone()));
698 }
699 Self {
700 data_type: data.data_type().clone(),
701 type_ids,
702 offsets,
703 fields: boxed_fields,
704 }
705 }
706}
707
708impl From<UnionArray> for ArrayData {
709 fn from(array: UnionArray) -> Self {
710 let len = array.len();
711 let f = match &array.data_type {
712 DataType::Union(f, _) => f,
713 _ => unreachable!(),
714 };
715 let buffers = match array.offsets {
716 Some(o) => vec![array.type_ids.into_inner(), o.into_inner()],
717 None => vec![array.type_ids.into_inner()],
718 };
719
720 let child = f
721 .iter()
722 .map(|(i, _)| array.fields[i as usize].as_ref().unwrap().to_data())
723 .collect();
724
725 let builder = ArrayDataBuilder::new(array.data_type)
726 .len(len)
727 .buffers(buffers)
728 .child_data(child);
729 unsafe { builder.build_unchecked() }
730 }
731}
732
733impl Array for UnionArray {
734 fn as_any(&self) -> &dyn Any {
735 self
736 }
737
738 fn to_data(&self) -> ArrayData {
739 self.clone().into()
740 }
741
742 fn into_data(self) -> ArrayData {
743 self.into()
744 }
745
746 fn data_type(&self) -> &DataType {
747 &self.data_type
748 }
749
750 fn slice(&self, offset: usize, length: usize) -> ArrayRef {
751 Arc::new(self.slice(offset, length))
752 }
753
754 fn len(&self) -> usize {
755 self.type_ids.len()
756 }
757
758 fn is_empty(&self) -> bool {
759 self.type_ids.is_empty()
760 }
761
762 fn shrink_to_fit(&mut self) {
763 self.type_ids.shrink_to_fit();
764 if let Some(offsets) = &mut self.offsets {
765 offsets.shrink_to_fit();
766 }
767 for array in self.fields.iter_mut().flatten() {
768 array.shrink_to_fit();
769 }
770 self.fields.shrink_to_fit();
771 }
772
773 fn offset(&self) -> usize {
774 0
775 }
776
777 fn nulls(&self) -> Option<&NullBuffer> {
778 None
779 }
780
781 fn logical_nulls(&self) -> Option<NullBuffer> {
782 let fields = match self.data_type() {
783 DataType::Union(fields, _) => fields,
784 _ => unreachable!(),
785 };
786
787 if fields.len() <= 1 {
788 return self.fields.iter().find_map(|field_opt| {
789 field_opt
790 .as_ref()
791 .and_then(|field| field.logical_nulls())
792 .map(|logical_nulls| {
793 if self.is_dense() {
794 self.gather_nulls(vec![(0, logical_nulls)]).into()
795 } else {
796 logical_nulls
797 }
798 })
799 });
800 }
801
802 let logical_nulls = self.fields_logical_nulls();
803
804 if logical_nulls.is_empty() {
805 return None;
806 }
807
808 let fully_null_count = logical_nulls
809 .iter()
810 .filter(|(_, nulls)| nulls.null_count() == nulls.len())
811 .count();
812
813 if fully_null_count == fields.len() {
814 if let Some((_, exactly_sized)) = logical_nulls
815 .iter()
816 .find(|(_, nulls)| nulls.len() == self.len())
817 {
818 return Some(exactly_sized.clone());
819 }
820
821 if let Some((_, bigger)) = logical_nulls
822 .iter()
823 .find(|(_, nulls)| nulls.len() > self.len())
824 {
825 return Some(bigger.slice(0, self.len()));
826 }
827
828 return Some(NullBuffer::new_null(self.len()));
829 }
830
831 let boolean_buffer = match &self.offsets {
832 Some(_) => self.gather_nulls(logical_nulls),
833 None => {
834 let gather_relative_cost = if cfg!(target_feature = "avx2") {
842 10
843 } else if cfg!(target_feature = "sse4.1") {
844 3
845 } else if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
846 2
848 } else {
849 0
853 };
854
855 let strategies = [
856 (SparseStrategy::Gather, gather_relative_cost, true),
857 (
858 SparseStrategy::MaskAllFieldsWithNullsSkipOne,
859 fields.len() - 1,
860 fields.len() == logical_nulls.len(),
861 ),
862 (
863 SparseStrategy::MaskSkipWithoutNulls,
864 logical_nulls.len(),
865 true,
866 ),
867 (
868 SparseStrategy::MaskSkipFullyNull,
869 fields.len() - fully_null_count,
870 true,
871 ),
872 ];
873
874 let (strategy, _, _) = strategies
875 .iter()
876 .filter(|(_, _, applicable)| *applicable)
877 .min_by_key(|(_, cost, _)| cost)
878 .unwrap();
879
880 match strategy {
881 SparseStrategy::Gather => self.gather_nulls(logical_nulls),
882 SparseStrategy::MaskAllFieldsWithNullsSkipOne => {
883 self.mask_sparse_all_with_nulls_skip_one(logical_nulls)
884 }
885 SparseStrategy::MaskSkipWithoutNulls => {
886 self.mask_sparse_skip_without_nulls(logical_nulls)
887 }
888 SparseStrategy::MaskSkipFullyNull => {
889 self.mask_sparse_skip_fully_null(logical_nulls)
890 }
891 }
892 }
893 };
894
895 let null_buffer = NullBuffer::from(boolean_buffer);
896
897 if null_buffer.null_count() > 0 {
898 Some(null_buffer)
899 } else {
900 None
901 }
902 }
903
904 fn is_nullable(&self) -> bool {
905 self.fields
906 .iter()
907 .flatten()
908 .any(|field| field.is_nullable())
909 }
910
911 fn get_buffer_memory_size(&self) -> usize {
912 let mut sum = self.type_ids.inner().capacity();
913 if let Some(o) = self.offsets.as_ref() {
914 sum += o.inner().capacity()
915 }
916 self.fields
917 .iter()
918 .flat_map(|x| x.as_ref().map(|x| x.get_buffer_memory_size()))
919 .sum::<usize>()
920 + sum
921 }
922
923 fn get_array_memory_size(&self) -> usize {
924 let mut sum = self.type_ids.inner().capacity();
925 if let Some(o) = self.offsets.as_ref() {
926 sum += o.inner().capacity()
927 }
928 std::mem::size_of::<Self>()
929 + self
930 .fields
931 .iter()
932 .flat_map(|x| x.as_ref().map(|x| x.get_array_memory_size()))
933 .sum::<usize>()
934 + sum
935 }
936}
937
938impl std::fmt::Debug for UnionArray {
939 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
940 let header = if self.is_dense() {
941 "UnionArray(Dense)\n["
942 } else {
943 "UnionArray(Sparse)\n["
944 };
945 writeln!(f, "{header}")?;
946
947 writeln!(f, "-- type id buffer:")?;
948 writeln!(f, "{:?}", self.type_ids)?;
949
950 if let Some(offsets) = &self.offsets {
951 writeln!(f, "-- offsets buffer:")?;
952 writeln!(f, "{offsets:?}")?;
953 }
954
955 let fields = match self.data_type() {
956 DataType::Union(fields, _) => fields,
957 _ => unreachable!(),
958 };
959
960 for (type_id, field) in fields.iter() {
961 let child = self.child(type_id);
962 writeln!(
963 f,
964 "-- child {}: \"{}\" ({:?})",
965 type_id,
966 field.name(),
967 field.data_type()
968 )?;
969 std::fmt::Debug::fmt(child, f)?;
970 writeln!(f)?;
971 }
972 writeln!(f, "]")
973 }
974}
975
/// The ways `logical_nulls` can compute the nulls of a sparse union.
enum SparseStrategy {
    /// Look up the validity of each slot individually (`gather_nulls`).
    Gather,
    /// Mask every field and derive the first field's selection from the others
    /// (`mask_sparse_all_with_nulls_skip_one`); only used when every field has nulls.
    MaskAllFieldsWithNullsSkipOne,
    /// Mask only the fields that have nulls (`mask_sparse_skip_without_nulls`).
    MaskSkipWithoutNulls,
    /// Mask the fields that are not entirely null (`mask_sparse_skip_fully_null`).
    MaskSkipFullyNull,
}
990
/// Index mask used by `gather_nulls`: the index is and-ed with the variant's value.
#[derive(Copy, Clone)]
#[repr(usize)]
enum Mask {
    /// Forces every index to 0.
    Zero = 0,
    // The value of `usize::MAX` depends on the target, which is what the lint
    // warns about; an all-ones mask is exactly what is wanted here.
    /// Leaves the index unchanged.
    #[allow(clippy::enum_clike_unportable_variant)]
    Max = usize::MAX,
}
999
/// Packs into a `u64` which entries of `type_ids_chunk` equal `type_id`.
///
/// Bit `i` of the result is set when `type_ids_chunk[i] == type_id`. The chunk
/// is expected to hold at most 64 type ids, one per bit.
fn selection_mask(type_ids_chunk: &[i8], type_id: i8) -> u64 {
    let mut packed = 0_u64;
    for (bit_idx, &candidate) in type_ids_chunk.iter().enumerate() {
        packed |= u64::from(candidate == type_id) << bit_idx;
    }
    packed
}
1009
1010fn without_nulls_selected(type_ids_chunk: &[i8], without_nulls_ids: &[i8]) -> u64 {
1012 without_nulls_ids
1013 .iter()
1014 .fold(0, |fully_valid_selected, field_type_id| {
1015 fully_valid_selected | selection_mask(type_ids_chunk, *field_type_id)
1016 })
1017}
1018
1019#[cfg(test)]
1020mod tests {
1021 use super::*;
1022 use std::collections::HashSet;
1023
1024 use crate::array::Int8Type;
1025 use crate::builder::UnionBuilder;
1026 use crate::cast::AsArray;
1027 use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
1028 use crate::{Float64Array, Int32Array, Int64Array, StringArray};
1029 use crate::{Int8Array, RecordBatch};
1030 use arrow_buffer::Buffer;
1031 use arrow_schema::{Field, Schema};
1032
1033 #[test]
1034 fn test_dense_i32() {
1035 let mut builder = UnionBuilder::new_dense();
1036 builder.append::<Int32Type>("a", 1).unwrap();
1037 builder.append::<Int32Type>("b", 2).unwrap();
1038 builder.append::<Int32Type>("c", 3).unwrap();
1039 builder.append::<Int32Type>("a", 4).unwrap();
1040 builder.append::<Int32Type>("c", 5).unwrap();
1041 builder.append::<Int32Type>("a", 6).unwrap();
1042 builder.append::<Int32Type>("b", 7).unwrap();
1043 let union = builder.build().unwrap();
1044
1045 let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
1046 let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
1047 let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];
1048
1049 assert_eq!(*union.type_ids(), expected_type_ids);
1051 for (i, id) in expected_type_ids.iter().enumerate() {
1052 assert_eq!(id, &union.type_id(i));
1053 }
1054
1055 assert_eq!(*union.offsets().unwrap(), expected_offsets);
1057 for (i, id) in expected_offsets.iter().enumerate() {
1058 assert_eq!(union.value_offset(i), *id as usize);
1059 }
1060
1061 assert_eq!(
1063 *union.child(0).as_primitive::<Int32Type>().values(),
1064 [1_i32, 4, 6]
1065 );
1066 assert_eq!(
1067 *union.child(1).as_primitive::<Int32Type>().values(),
1068 [2_i32, 7]
1069 );
1070 assert_eq!(
1071 *union.child(2).as_primitive::<Int32Type>().values(),
1072 [3_i32, 5]
1073 );
1074
1075 assert_eq!(expected_array_values.len(), union.len());
1076 for (i, expected_value) in expected_array_values.iter().enumerate() {
1077 assert!(!union.is_null(i));
1078 let slot = union.value(i);
1079 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1080 assert_eq!(slot.len(), 1);
1081 let value = slot.value(0);
1082 assert_eq!(expected_value, &value);
1083 }
1084 }
1085
1086 #[test]
1087 fn slice_union_array_single_field() {
1088 let union_array = {
1091 let mut builder = UnionBuilder::new_dense();
1092 builder.append::<Int32Type>("a", 1).unwrap();
1093 builder.append_null::<Int32Type>("a").unwrap();
1094 builder.append::<Int32Type>("a", 3).unwrap();
1095 builder.append_null::<Int32Type>("a").unwrap();
1096 builder.append::<Int32Type>("a", 4).unwrap();
1097 builder.build().unwrap()
1098 };
1099
1100 let union_slice = union_array.slice(1, 3);
1102 let logical_nulls = union_slice.logical_nulls().unwrap();
1103
1104 assert_eq!(logical_nulls.len(), 3);
1105 assert!(logical_nulls.is_null(0));
1106 assert!(logical_nulls.is_valid(1));
1107 assert!(logical_nulls.is_null(2));
1108 }
1109
1110 #[test]
1111 #[cfg_attr(miri, ignore)]
1112 fn test_dense_i32_large() {
1113 let mut builder = UnionBuilder::new_dense();
1114
1115 let expected_type_ids = vec![0_i8; 1024];
1116 let expected_offsets: Vec<_> = (0..1024).collect();
1117 let expected_array_values: Vec<_> = (1..=1024).collect();
1118
1119 expected_array_values
1120 .iter()
1121 .for_each(|v| builder.append::<Int32Type>("a", *v).unwrap());
1122
1123 let union = builder.build().unwrap();
1124
1125 assert_eq!(*union.type_ids(), expected_type_ids);
1127 for (i, id) in expected_type_ids.iter().enumerate() {
1128 assert_eq!(id, &union.type_id(i));
1129 }
1130
1131 assert_eq!(*union.offsets().unwrap(), expected_offsets);
1133 for (i, id) in expected_offsets.iter().enumerate() {
1134 assert_eq!(union.value_offset(i), *id as usize);
1135 }
1136
1137 for (i, expected_value) in expected_array_values.iter().enumerate() {
1138 assert!(!union.is_null(i));
1139 let slot = union.value(i);
1140 let slot = slot.as_primitive::<Int32Type>();
1141 assert_eq!(slot.len(), 1);
1142 let value = slot.value(0);
1143 assert_eq!(expected_value, &value);
1144 }
1145 }
1146
1147 #[test]
1148 fn test_dense_mixed() {
1149 let mut builder = UnionBuilder::new_dense();
1150 builder.append::<Int32Type>("a", 1).unwrap();
1151 builder.append::<Int64Type>("c", 3).unwrap();
1152 builder.append::<Int32Type>("a", 4).unwrap();
1153 builder.append::<Int64Type>("c", 5).unwrap();
1154 builder.append::<Int32Type>("a", 6).unwrap();
1155 let union = builder.build().unwrap();
1156
1157 assert_eq!(5, union.len());
1158 for i in 0..union.len() {
1159 let slot = union.value(i);
1160 assert!(!union.is_null(i));
1161 match i {
1162 0 => {
1163 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1164 assert_eq!(slot.len(), 1);
1165 let value = slot.value(0);
1166 assert_eq!(1_i32, value);
1167 }
1168 1 => {
1169 let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1170 assert_eq!(slot.len(), 1);
1171 let value = slot.value(0);
1172 assert_eq!(3_i64, value);
1173 }
1174 2 => {
1175 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1176 assert_eq!(slot.len(), 1);
1177 let value = slot.value(0);
1178 assert_eq!(4_i32, value);
1179 }
1180 3 => {
1181 let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1182 assert_eq!(slot.len(), 1);
1183 let value = slot.value(0);
1184 assert_eq!(5_i64, value);
1185 }
1186 4 => {
1187 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1188 assert_eq!(slot.len(), 1);
1189 let value = slot.value(0);
1190 assert_eq!(6_i32, value);
1191 }
1192 _ => unreachable!(),
1193 }
1194 }
1195 }
1196
1197 #[test]
1198 fn test_dense_mixed_with_nulls() {
1199 let mut builder = UnionBuilder::new_dense();
1200 builder.append::<Int32Type>("a", 1).unwrap();
1201 builder.append::<Int64Type>("c", 3).unwrap();
1202 builder.append::<Int32Type>("a", 10).unwrap();
1203 builder.append_null::<Int32Type>("a").unwrap();
1204 builder.append::<Int32Type>("a", 6).unwrap();
1205 let union = builder.build().unwrap();
1206
1207 assert_eq!(5, union.len());
1208 for i in 0..union.len() {
1209 let slot = union.value(i);
1210 match i {
1211 0 => {
1212 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1213 assert!(!slot.is_null(0));
1214 assert_eq!(slot.len(), 1);
1215 let value = slot.value(0);
1216 assert_eq!(1_i32, value);
1217 }
1218 1 => {
1219 let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
1220 assert!(!slot.is_null(0));
1221 assert_eq!(slot.len(), 1);
1222 let value = slot.value(0);
1223 assert_eq!(3_i64, value);
1224 }
1225 2 => {
1226 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1227 assert!(!slot.is_null(0));
1228 assert_eq!(slot.len(), 1);
1229 let value = slot.value(0);
1230 assert_eq!(10_i32, value);
1231 }
1232 3 => assert!(slot.is_null(0)),
1233 4 => {
1234 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1235 assert!(!slot.is_null(0));
1236 assert_eq!(slot.len(), 1);
1237 let value = slot.value(0);
1238 assert_eq!(6_i32, value);
1239 }
1240 _ => unreachable!(),
1241 }
1242 }
1243 }
1244
1245 #[test]
1246 fn test_dense_mixed_with_nulls_and_offset() {
1247 let mut builder = UnionBuilder::new_dense();
1248 builder.append::<Int32Type>("a", 1).unwrap();
1249 builder.append::<Int64Type>("c", 3).unwrap();
1250 builder.append::<Int32Type>("a", 10).unwrap();
1251 builder.append_null::<Int32Type>("a").unwrap();
1252 builder.append::<Int32Type>("a", 6).unwrap();
1253 let union = builder.build().unwrap();
1254
1255 let slice = union.slice(2, 3);
1256 let new_union = slice.as_any().downcast_ref::<UnionArray>().unwrap();
1257
1258 assert_eq!(3, new_union.len());
1259 for i in 0..new_union.len() {
1260 let slot = new_union.value(i);
1261 match i {
1262 0 => {
1263 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1264 assert!(!slot.is_null(0));
1265 assert_eq!(slot.len(), 1);
1266 let value = slot.value(0);
1267 assert_eq!(10_i32, value);
1268 }
1269 1 => assert!(slot.is_null(0)),
1270 2 => {
1271 let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
1272 assert!(!slot.is_null(0));
1273 assert_eq!(slot.len(), 1);
1274 let value = slot.value(0);
1275 assert_eq!(6_i32, value);
1276 }
1277 _ => unreachable!(),
1278 }
1279 }
1280 }
1281
1282 #[test]
1283 fn test_dense_mixed_with_str() {
1284 let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
1285 let int_array = Int32Array::from(vec![5, 6]);
1286 let float_array = Float64Array::from(vec![10.0]);
1287
1288 let type_ids = [1, 0, 0, 2, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();
1289 let offsets = [0, 0, 1, 0, 2, 1]
1290 .into_iter()
1291 .collect::<ScalarBuffer<i32>>();
1292
1293 let fields = [
1294 (0, Arc::new(Field::new("A", DataType::Utf8, false))),
1295 (1, Arc::new(Field::new("B", DataType::Int32, false))),
1296 (2, Arc::new(Field::new("C", DataType::Float64, false))),
1297 ]
1298 .into_iter()
1299 .collect::<UnionFields>();
1300 let children = [
1301 Arc::new(string_array) as Arc<dyn Array>,
1302 Arc::new(int_array),
1303 Arc::new(float_array),
1304 ]
1305 .into_iter()
1306 .collect();
1307 let array =
1308 UnionArray::try_new(fields, type_ids.clone(), Some(offsets.clone()), children).unwrap();
1309
1310 assert_eq!(*array.type_ids(), type_ids);
1312 for (i, id) in type_ids.iter().enumerate() {
1313 assert_eq!(id, &array.type_id(i));
1314 }
1315
1316 assert_eq!(*array.offsets().unwrap(), offsets);
1318 for (i, id) in offsets.iter().enumerate() {
1319 assert_eq!(*id as usize, array.value_offset(i));
1320 }
1321
1322 assert_eq!(6, array.len());
1324
1325 let slot = array.value(0);
1326 let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1327 assert_eq!(5, value);
1328
1329 let slot = array.value(1);
1330 let value = slot
1331 .as_any()
1332 .downcast_ref::<StringArray>()
1333 .unwrap()
1334 .value(0);
1335 assert_eq!("foo", value);
1336
1337 let slot = array.value(2);
1338 let value = slot
1339 .as_any()
1340 .downcast_ref::<StringArray>()
1341 .unwrap()
1342 .value(0);
1343 assert_eq!("bar", value);
1344
1345 let slot = array.value(3);
1346 let value = slot
1347 .as_any()
1348 .downcast_ref::<Float64Array>()
1349 .unwrap()
1350 .value(0);
1351 assert_eq!(10.0, value);
1352
1353 let slot = array.value(4);
1354 let value = slot
1355 .as_any()
1356 .downcast_ref::<StringArray>()
1357 .unwrap()
1358 .value(0);
1359 assert_eq!("baz", value);
1360
1361 let slot = array.value(5);
1362 let value = slot.as_any().downcast_ref::<Int32Array>().unwrap().value(0);
1363 assert_eq!(6, value);
1364 }
1365
/// Builds a sparse union of `Int32` values spread over the fields "a", "b" and "c", then
/// checks the type ids, the absence of offsets, each child's value buffer and every slot.
#[test]
fn test_sparse_i32() {
    let mut builder = UnionBuilder::new_sparse();
    let appended = [
        ("a", 1),
        ("b", 2),
        ("c", 3),
        ("a", 4),
        ("c", 5),
        ("a", 6),
        ("b", 7),
    ];
    for (field_name, value) in appended {
        builder.append::<Int32Type>(field_name, value).unwrap();
    }
    let union = builder.build().unwrap();

    // Type ids must match both as a whole buffer and element by element.
    let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (row, expected_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(expected_id, &union.type_id(row));
    }

    // This union was built in sparse mode, so no offsets buffer is expected.
    assert!(union.offsets().is_none());

    // Every child spans the full union length; rows belonging to another field hold 0.
    let expected_children: [(i8, [i32; 7]); 3] = [
        (0, [1, 0, 0, 4, 0, 6, 0]),
        (1, [0, 2, 0, 0, 0, 0, 7]),
        (2, [0, 0, 3, 0, 5, 0, 0]),
    ];
    for (type_id, expected_values) in expected_children {
        let child = union.child(type_id);
        assert_eq!(*child.as_primitive::<Int32Type>().values(), expected_values);
    }

    // Each logical slot is a non-null, single-element Int32 array.
    let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];
    assert_eq!(expected_array_values.len(), union.len());
    for (row, expected_value) in expected_array_values.iter().enumerate() {
        assert!(!union.is_null(row));
        let slot = union.value(row);
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected_value, &ints.value(0));
    }
}
1414
/// Builds a sparse union alternating `Int32` ("a") and `Float64` ("c") values and checks the
/// type ids, the absence of offsets and the concrete value stored in every slot.
#[test]
fn test_sparse_mixed() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    builder.append::<Float64Type>("c", 5.0).unwrap();
    builder.append::<Int32Type>("a", 6).unwrap();
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 1, 0, 1, 0];
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (row, expected_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(expected_id, &union.type_id(row));
    }

    // This union was built in sparse mode, so no offsets buffer is expected.
    assert!(union.offsets().is_none());

    // Asserts that `slot` is a single-element Int32 array holding `expected`.
    let assert_int_slot = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };
    // Asserts that `slot` is a single-element Float64 array holding `expected`.
    let assert_float_slot = |slot: &ArrayRef, expected: f64| {
        let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
        assert_eq!(floats.len(), 1);
        assert_eq!(expected, floats.value(0));
    };

    for row in 0..union.len() {
        let slot = union.value(row);
        assert!(!union.is_null(row));
        match row {
            0 => assert_int_slot(&slot, 1),
            1 => assert_float_slot(&slot, 3.0),
            2 => assert_int_slot(&slot, 4),
            3 => assert_float_slot(&slot, 5.0),
            4 => assert_int_slot(&slot, 6),
            _ => unreachable!(),
        }
    }
}
1474
/// Builds a sparse union that contains one appended null and checks that the null shows up
/// in the selected child slot while the remaining slots keep their values.
#[test]
fn test_sparse_mixed_with_nulls() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let union = builder.build().unwrap();

    let expected_type_ids = vec![0_i8, 0, 1, 0];
    assert_eq!(*union.type_ids(), expected_type_ids);
    for (row, expected_id) in expected_type_ids.iter().enumerate() {
        assert_eq!(expected_id, &union.type_id(row));
    }

    // This union was built in sparse mode, so no offsets buffer is expected.
    assert!(union.offsets().is_none());

    // Asserts that `slot` is a valid, single-element Int32 array holding `expected`.
    let assert_int_slot = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_any().downcast_ref::<Int32Array>().unwrap();
        assert!(!ints.is_null(0));
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };
    // Asserts that `slot` is a valid, single-element Float64 array holding `expected`.
    let assert_float_slot = |slot: &ArrayRef, expected: f64| {
        let floats = slot.as_any().downcast_ref::<Float64Array>().unwrap();
        assert!(!floats.is_null(0));
        assert_eq!(floats.len(), 1);
        assert_eq!(expected, floats.value(0));
    };

    for row in 0..union.len() {
        let slot = union.value(row);
        match row {
            0 => assert_int_slot(&slot, 1),
            // The second append was `append_null`, so the extracted slot must be null.
            1 => assert!(slot.is_null(0)),
            2 => assert_float_slot(&slot, 3.0),
            3 => assert_int_slot(&slot, 4),
            _ => unreachable!(),
        }
    }
}
1524
/// Slices a sparse union containing nulls (dropping its first slot) and checks that the
/// sliced view reports the remaining four slots correctly.
#[test]
fn test_sparse_mixed_with_nulls_and_offset() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Int32Type>("a", 1).unwrap();
    builder.append_null::<Int32Type>("a").unwrap();
    builder.append::<Float64Type>("c", 3.0).unwrap();
    builder.append_null::<Float64Type>("c").unwrap();
    builder.append::<Int32Type>("a", 4).unwrap();
    let union = builder.build().unwrap();

    // Skip the first slot and keep the remaining four.
    let sliced = union.slice(1, 4);
    let sliced_union = sliced.as_any().downcast_ref::<UnionArray>().unwrap();
    assert_eq!(4, sliced_union.len());

    // Asserts that `slot` is a valid, single-element Float64 array holding `expected`.
    let assert_float_slot = |slot: &ArrayRef, expected: f64| {
        let floats = slot.as_primitive::<Float64Type>();
        assert!(!floats.is_null(0));
        assert_eq!(floats.len(), 1);
        assert_eq!(expected, floats.value(0));
    };
    // Asserts that `slot` is a valid, single-element Int32 array holding `expected`.
    let assert_int_slot = |slot: &ArrayRef, expected: i32| {
        let ints = slot.as_primitive::<Int32Type>();
        assert!(!ints.is_null(0));
        assert_eq!(ints.len(), 1);
        assert_eq!(expected, ints.value(0));
    };

    for row in 0..sliced_union.len() {
        let slot = sliced_union.value(row);
        match row {
            0 | 2 => assert!(slot.is_null(0)),
            1 => assert_float_slot(&slot, 3.0),
            3 => assert_int_slot(&slot, 4),
            _ => unreachable!(),
        }
    }
}
1562
1563 fn test_union_validity(union_array: &UnionArray) {
1564 assert_eq!(union_array.null_count(), 0);
1565
1566 for i in 0..union_array.len() {
1567 assert!(!union_array.is_null(i));
1568 assert!(union_array.is_valid(i));
1569 }
1570 }
1571
/// Runs the shared validity assertions against a sparse and a dense union that were both
/// built from the same sequence of values and appended nulls.
#[test]
fn test_union_array_validity() {
    // Sparse first, then dense: same order as the checks are expected to run.
    for mut builder in [UnionBuilder::new_sparse(), UnionBuilder::new_dense()] {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        let union = builder.build().unwrap();

        test_union_validity(&union);
    }
}
1594
/// Appending a value of a different primitive type to an existing field name must fail with
/// a message naming both the attempted and the existing type.
#[test]
fn test_type_check() {
    let mut builder = UnionBuilder::new_sparse();
    builder.append::<Float32Type>("a", 1.0).unwrap();

    // Field "a" is already Float32, so writing an Int32 into it is rejected.
    let message = builder.append::<Int32Type>("a", 1).unwrap_err().to_string();
    let expected_fragment =
        "Attempt to write col \"a\" with type Int32 doesn't match existing type Float32";
    assert!(message.contains(expected_fragment), "{message}");
}
1608
/// Wraps a union in a `RecordBatch`, slices the batch, and checks that the sliced union column
/// exposes the expected type ids and slot values for both sparse and dense unions.
#[test]
fn slice_union_array() {
    /// Fills `builder` with five slots (value, null, value, null, value) and builds the union.
    fn create_union(mut builder: UnionBuilder) -> UnionArray {
        builder.append::<Int32Type>("a", 1).unwrap();
        builder.append_null::<Int32Type>("a").unwrap();
        builder.append::<Float64Type>("c", 3.0).unwrap();
        builder.append_null::<Float64Type>("c").unwrap();
        builder.append::<Int32Type>("a", 4).unwrap();
        builder.build().unwrap()
    }

    /// Creates a single-column record batch whose only column is `union`.
    fn create_batch(union: UnionArray) -> RecordBatch {
        let column_field = Field::new("struct_array", union.data_type().clone(), true);
        let schema = Arc::new(Schema::new(vec![column_field]));

        RecordBatch::try_new(schema, vec![Arc::new(union)]).unwrap()
    }

    /// Checks the three slots that remain visible after slicing the batch with `slice(1, 3)`.
    fn test_slice_union(record_batch_slice: RecordBatch) {
        let column = record_batch_slice.column(0);
        let union_slice = column.as_any().downcast_ref::<UnionArray>().unwrap();

        assert_eq!(union_slice.type_id(0), 0);
        assert_eq!(union_slice.type_id(1), 1);
        assert_eq!(union_slice.type_id(2), 1);

        // Slot 0 of the slice: the null appended to "a".
        let first = union_slice.value(0);
        let first_ints = first.as_primitive::<Int32Type>();
        assert_eq!(first_ints.len(), 1);
        assert!(first_ints.is_null(0));

        // Slot 1 of the slice: the value 3.0 appended to "c".
        let second = union_slice.value(1);
        let second_floats = second.as_primitive::<Float64Type>();
        assert_eq!(second_floats.len(), 1);
        assert!(second_floats.is_valid(0));
        assert_eq!(second_floats.value(0), 3.0);

        // Slot 2 of the slice: the null appended to "c".
        let third = union_slice.value(2);
        let third_floats = third.as_primitive::<Float64Type>();
        assert_eq!(third_floats.len(), 1);
        assert!(third_floats.is_null(0));
    }

    // Exercise the sparse layout first, then the dense one.
    for builder in [UnionBuilder::new_sparse(), UnionBuilder::new_dense()] {
        let record_batch = create_batch(create_union(builder));
        let record_batch_slice = record_batch.slice(1, 3);
        test_slice_union(record_batch_slice);
    }
}
1673
/// Builds a dense union whose fields use the non-contiguous type ids 8, 4 and 9 directly from
/// `ArrayData`, then checks that every slot resolves to the right child type and value.
#[test]
fn test_custom_type_ids() {
    /// Expected content of one union slot.
    enum Expected {
        Int(i32),
        Str(&'static str),
        Float(f64),
    }

    let union_fields = UnionFields::new(
        vec![8, 4, 9],
        vec![
            Field::new("strings", DataType::Utf8, false),
            Field::new("integers", DataType::Int32, false),
            Field::new("floats", DataType::Float64, false),
        ],
    );
    let data_type = DataType::Union(union_fields, UnionMode::Dense);

    let strings = StringArray::from(vec!["foo", "bar", "baz"]);
    let integers = Int32Array::from(vec![5, 6, 4]);
    let floats = Float64Array::from(vec![10.0]);

    let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
    let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);

    let data = ArrayData::builder(data_type)
        .len(7)
        .buffers(vec![type_ids, value_offsets])
        .child_data(vec![
            strings.into_data(),
            integers.into_data(),
            floats.into_data(),
        ])
        .build()
        .unwrap();
    let array = UnionArray::from(data);

    // One entry per union slot, in slot order.
    let expected_slots = [
        Expected::Int(5),
        Expected::Str("foo"),
        Expected::Int(6),
        Expected::Str("bar"),
        Expected::Float(10.0),
        Expected::Int(4),
        Expected::Str("baz"),
    ];

    for (row, expected) in expected_slots.into_iter().enumerate() {
        let v = array.value(row);
        match expected {
            Expected::Int(value) => {
                assert_eq!(v.data_type(), &DataType::Int32);
                assert_eq!(v.len(), 1);
                assert_eq!(v.as_primitive::<Int32Type>().value(0), value);
            }
            Expected::Str(value) => {
                assert_eq!(v.data_type(), &DataType::Utf8);
                assert_eq!(v.len(), 1);
                assert_eq!(v.as_string::<i32>().value(0), value);
            }
            Expected::Float(value) => {
                assert_eq!(v.data_type(), &DataType::Float64);
                assert_eq!(v.len(), 1);
                assert_eq!(v.as_primitive::<Float64Type>().value(0), value);
            }
        }
    }
}
1743
/// Decomposes a dense and a sparse union with `into_parts` and checks that the parts are as
/// expected and can be fed back into `UnionArray::try_new`.
#[test]
fn into_parts() {
    // Dense union: offsets must be present.
    let mut dense_builder = UnionBuilder::new_dense();
    dense_builder.append::<Int32Type>("a", 1).unwrap();
    dense_builder.append::<Int8Type>("b", 2).unwrap();
    dense_builder.append::<Int32Type>("a", 3).unwrap();
    let dense_union = dense_builder.build().unwrap();

    let expected_fields = [
        Arc::new(Field::new("a", DataType::Int32, false)),
        Arc::new(Field::new("b", DataType::Int8, false)),
    ];
    let (union_fields, type_ids, offsets, children) = dense_union.into_parts();
    let actual_fields: Vec<_> = union_fields
        .iter()
        .map(|(_, field)| Arc::clone(field))
        .collect();
    assert_eq!(actual_fields, expected_fields);
    assert_eq!(type_ids, [0, 1, 0]);
    assert!(offsets.is_some());
    assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);

    let rebuilt_dense = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt_dense.is_ok());
    assert_eq!(rebuilt_dense.unwrap().len(), 3);

    // Sparse union: no offsets are returned.
    let mut sparse_builder = UnionBuilder::new_sparse();
    sparse_builder.append::<Int32Type>("a", 1).unwrap();
    sparse_builder.append::<Int8Type>("b", 2).unwrap();
    sparse_builder.append::<Int32Type>("a", 3).unwrap();
    let sparse_union = sparse_builder.build().unwrap();

    let (union_fields, type_ids, offsets, children) = sparse_union.into_parts();
    assert_eq!(type_ids, [0, 1, 0]);
    assert!(offsets.is_none());

    let rebuilt_sparse = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt_sparse.is_ok());
    assert_eq!(rebuilt_sparse.unwrap().len(), 3);
}
1786
/// Decomposes a dense union that uses the custom type ids 8, 4 and 9 and checks that the
/// returned parts still form a valid union of the same length.
#[test]
fn into_parts_custom_type_ids() {
    let set_field_type_ids: [i8; 3] = [8, 4, 9];
    let union_fields = UnionFields::new(
        set_field_type_ids,
        [
            Field::new("strings", DataType::Utf8, false),
            Field::new("integers", DataType::Int32, false),
            Field::new("floats", DataType::Float64, false),
        ],
    );
    let data_type = DataType::Union(union_fields, UnionMode::Dense);

    let strings = StringArray::from(vec!["foo", "bar", "baz"]);
    let integers = Int32Array::from(vec![5, 6, 4]);
    let floats = Float64Array::from(vec![10.0]);
    let type_id_buffer = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
    let offset_buffer = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);

    let data = ArrayData::builder(data_type)
        .len(7)
        .buffers(vec![type_id_buffer, offset_buffer])
        .child_data(vec![
            strings.into_data(),
            integers.into_data(),
            floats.into_data(),
        ])
        .build()
        .unwrap();
    let array = UnionArray::from(data);

    let (union_fields, type_ids, offsets, children) = array.into_parts();

    // The set of type ids in use must be exactly the set declared on the fields.
    let used_ids = type_ids.iter().collect::<HashSet<_>>();
    let declared_ids = set_field_type_ids.iter().collect::<HashSet<_>>();
    assert_eq!(used_ids, declared_ids);

    let rebuilt = UnionArray::try_new(union_fields, type_ids, offsets, children);
    assert!(rebuilt.is_ok());
    let array = rebuilt.unwrap();
    assert_eq!(array.len(), 7);
}
1828
/// Exercises the validation errors of `UnionArray::try_new`: mismatched sparse child lengths,
/// unknown type ids, out-of-range offsets, mismatched buffer lengths and a wrong child count.
#[test]
fn test_invalid() {
    let fields = UnionFields::new(
        [3, 2],
        [
            Field::new("a", DataType::Utf8, false),
            Field::new("b", DataType::Utf8, false),
        ],
    );

    // Runs `try_new` with the shared fields and returns the rendered error message.
    let error_message = |type_ids: ScalarBuffer<i8>,
                         offsets: Option<ScalarBuffer<i32>>,
                         children: Vec<ArrayRef>| {
        UnionArray::try_new(fields.clone(), type_ids, offsets, children)
            .unwrap_err()
            .to_string()
    };

    // Sparse layout: both children have two elements.
    let sparse_children: Vec<ArrayRef> = vec![
        Arc::new(StringArray::from_iter_values(["a", "b"])) as ArrayRef,
        Arc::new(StringArray::from_iter_values(["c", "d"])) as ArrayRef,
    ];

    // Three type ids but children of length two.
    assert_eq!(
        error_message(vec![3, 3, 2].into(), None, sparse_children.clone()),
        "Invalid argument error: Sparse union child arrays must be equal in length to the length of the union"
    );

    // Type id 1 is not declared on any field.
    assert_eq!(
        error_message(vec![1, 2].into(), None, sparse_children.clone()),
        "Invalid argument error: Type Ids values must match one of the field type ids"
    );

    // Type id 7 is not declared on any field.
    assert_eq!(
        error_message(vec![7, 2].into(), None, sparse_children),
        "Invalid argument error: Type Ids values must match one of the field type ids"
    );

    // Dense layout: children of length two and one.
    let dense_children: Vec<ArrayRef> = vec![
        Arc::new(StringArray::from_iter_values(["a", "b"])) as ArrayRef,
        Arc::new(StringArray::from_iter_values(["c"])) as ArrayRef,
    ];
    let type_ids = ScalarBuffer::from(vec![3_i8, 3, 2]);

    // Valid offsets are accepted.
    UnionArray::try_new(
        fields.clone(),
        type_ids.clone(),
        Some(vec![0, 1, 0].into()),
        dense_children.clone(),
    )
    .unwrap();

    // The last offset points past the end of the single-element child.
    assert_eq!(
        error_message(
            type_ids.clone(),
            Some(vec![0, 1, 1].into()),
            dense_children.clone()
        ),
        "Invalid argument error: Offsets must be positive and within the length of the Array"
    );

    // Two offsets for three type ids.
    assert_eq!(
        error_message(type_ids.clone(), Some(vec![0, 1].into()), dense_children),
        "Invalid argument error: Type Ids and Offsets lengths must match"
    );

    // Two fields but no children at all.
    assert_eq!(
        error_message(type_ids, None, vec![]),
        "Invalid argument error: Union fields length must match child arrays length"
    );
}
1899
/// Checks `logical_nulls` on small unions: no fields, non-nullable fields, nullable fields
/// whose children hold no nulls, and sparse/dense unions whose children are entirely null.
#[test]
fn test_logical_nulls_fast_paths() {
    // Union without any fields or rows.
    let empty_union =
        UnionArray::try_new(UnionFields::empty(), vec![].into(), None, vec![]).unwrap();
    assert_eq!(empty_union.logical_nulls(), None);

    // Fields declared as non-nullable.
    let non_nullable_fields = UnionFields::new(
        [1, 3],
        [
            Field::new("a", DataType::Int8, false),
            Field::new("b", DataType::Int8, false),
        ],
    );
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int8Array::from_value(5, 1)) as ArrayRef,
        Arc::new(Int8Array::from_value(5, 1)),
    ];
    let non_nullable_union =
        UnionArray::try_new(non_nullable_fields, vec![1].into(), None, children).unwrap();
    assert_eq!(non_nullable_union.logical_nulls(), None);

    // Fields declared as nullable; reused by the remaining cases.
    let nullable_fields = UnionFields::new(
        [1, 3],
        [
            Field::new("a", DataType::Int8, true),
            Field::new("b", DataType::Int8, true),
        ],
    );

    // Nullable fields, but the children contain no nulls.
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int8Array::from_value(-5, 2)) as ArrayRef,
        Arc::new(Int8Array::from_value(-5, 2)),
    ];
    let no_nulls_union =
        UnionArray::try_new(nullable_fields.clone(), vec![1, 1].into(), None, children).unwrap();
    assert_eq!(no_nulls_union.logical_nulls(), None);

    // Sparse union whose children are entirely null.
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int8Array::new_null(2)) as ArrayRef,
        Arc::new(Int8Array::new_null(2)),
    ];
    let all_null_sparse =
        UnionArray::try_new(nullable_fields.clone(), vec![1, 1].into(), None, children).unwrap();
    assert_eq!(
        all_null_sparse.logical_nulls(),
        Some(NullBuffer::new_null(2))
    );

    // Dense union whose children are entirely null.
    let children: Vec<ArrayRef> = vec![
        Arc::new(Int8Array::new_null(3)) as ArrayRef,
        Arc::new(Int8Array::new_null(3)),
    ];
    let all_null_dense = UnionArray::try_new(
        nullable_fields.clone(),
        vec![1, 1].into(),
        Some(vec![0, 1].into()),
        children,
    )
    .unwrap();
    assert_eq!(
        all_null_dense.logical_nulls(),
        Some(NullBuffer::new_null(2))
    );
}
1975
/// Dense union over a fully valid, a partially null and a fully null child: both
/// `logical_nulls` and `gather_nulls` must yield the same validity bits.
#[test]
fn test_dense_union_logical_nulls_gather() {
    let ints = Int32Array::from(vec![1, 2]);
    let floats = Float64Array::from(vec![Some(3.2), None]);
    let strings = StringArray::new_null(1);

    let type_ids = ScalarBuffer::<i8>::from(vec![1, 1, 3, 3, 4, 4]);
    let offsets = ScalarBuffer::<i32>::from(vec![0, 1, 0, 1, 0, 0]);

    let children: Vec<ArrayRef> = vec![
        Arc::new(ints) as ArrayRef,
        Arc::new(floats),
        Arc::new(strings),
    ];

    let array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap();

    // Rows 0-1 select valid ints, row 2 the valid float, row 3 the null float and
    // rows 4-5 the single null string.
    let expected = BooleanBuffer::from(vec![true, true, true, false, false, false]);

    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
}
2000
/// Sparse union with one fully null child and one partially null child: `logical_nulls` and
/// `mask_sparse_all_with_nulls_skip_one` must agree, for a short and for a 160-row input.
#[test]
fn test_sparse_union_logical_nulls_mask_all_nulls_skip_one() {
    let fields = UnionFields::new(
        [1, 3],
        [
            Field::new("A", DataType::Int32, true),
            Field::new("B", DataType::Float64, true),
        ],
    );

    // Builds the sparse union under test from its two children and the type ids.
    let build_union = |ints: Int32Array, floats: Float64Array, type_ids: ScalarBuffer<i8>| {
        let children: Vec<ArrayRef> = vec![Arc::new(ints) as ArrayRef, Arc::new(floats)];
        UnionArray::try_new(fields.clone(), type_ids, None, children).unwrap()
    };

    // Short input: four rows.
    let array = build_union(
        Int32Array::new_null(4),
        Float64Array::from(vec![None, None, Some(3.2), None]),
        ScalarBuffer::<i8>::from(vec![1, 1, 3, 3]),
    );

    let expected = BooleanBuffer::from(vec![false, false, true, false]);

    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(
        expected,
        array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
    );

    // Longer input: two full 64-row groups plus 32 extra rows.
    let len = 2 * 64 + 32;

    let array = build_union(
        Int32Array::new_null(len),
        [Some(3.2), None]
            .into_iter()
            .cycle()
            .take(len)
            .collect::<Float64Array>(),
        [1, 1, 3, 3]
            .into_iter()
            .cycle()
            .take(len)
            .collect::<ScalarBuffer<i8>>(),
    );

    let expected = [false, false, true, false]
        .into_iter()
        .cycle()
        .take(len)
        .collect::<BooleanBuffer>();

    assert_eq!(array.len(), len);
    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(
        expected,
        array.mask_sparse_all_with_nulls_skip_one(array.fields_logical_nulls())
    );
}
2052
/// Sparse union with two fully valid children and one child containing nulls: `logical_nulls`
/// and `mask_sparse_skip_without_nulls` must agree, for a short and for a 160-row input.
#[test]
fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_valid() {
    // Builds the sparse union under test from its three children and the type ids.
    let build_union = |ints: Int32Array,
                       floats: Float64Array,
                       strings: StringArray,
                       type_ids: ScalarBuffer<i8>| {
        let children: Vec<ArrayRef> = vec![
            Arc::new(ints) as ArrayRef,
            Arc::new(floats),
            Arc::new(strings),
        ];
        UnionArray::try_new(union_fields(), type_ids, None, children).unwrap()
    };

    // Short input: six rows, string child entirely null.
    let array = build_union(
        Int32Array::from_value(2, 6),
        Float64Array::from_value(4.2, 6),
        StringArray::new_null(6),
        ScalarBuffer::<i8>::from(vec![1, 1, 3, 3, 4, 4]),
    );

    let expected = BooleanBuffer::from(vec![true, true, true, true, false, false]);

    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(
        expected,
        array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
    );

    // Longer input: two full 64-row groups plus 32 extra rows; string child alternates
    // between null and "a".
    let len = 2 * 64 + 32;

    let array = build_union(
        Int32Array::from_value(2, len),
        Float64Array::from_value(4.2, len),
        [None, Some("a")]
            .into_iter()
            .cycle()
            .take(len)
            .collect::<StringArray>(),
        [1, 1, 3, 3, 4, 4]
            .into_iter()
            .cycle()
            .take(len)
            .collect::<ScalarBuffer<i8>>(),
    );

    let expected = [true, true, true, true, false, true]
        .into_iter()
        .cycle()
        .take(len)
        .collect::<BooleanBuffer>();

    assert_eq!(array.len(), len);
    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(
        expected,
        array.mask_sparse_skip_without_nulls(array.fields_logical_nulls())
    );
}
2107
/// Sparse union with two fully null children and one fully valid child: `logical_nulls` and
/// `mask_sparse_skip_fully_null` must agree, for a short and for a 160-row input.
#[test]
fn test_sparse_union_logical_mask_mixed_nulls_skip_fully_null() {
    // Builds the sparse union under test from its three children and the type ids.
    let build_union = |ints: Int32Array,
                       floats: Float64Array,
                       strings: StringArray,
                       type_ids: ScalarBuffer<i8>| {
        let children: Vec<ArrayRef> = vec![
            Arc::new(ints) as ArrayRef,
            Arc::new(floats),
            Arc::new(strings),
        ];
        UnionArray::try_new(union_fields(), type_ids, None, children).unwrap()
    };

    // Short input: six rows.
    let array = build_union(
        Int32Array::new_null(6),
        Float64Array::from_value(4.2, 6),
        StringArray::new_null(6),
        ScalarBuffer::<i8>::from(vec![1, 1, 3, 3, 4, 4]),
    );

    // Only the rows selecting the float child are valid.
    let expected = BooleanBuffer::from(vec![false, false, true, true, false, false]);

    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(
        expected,
        array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
    );

    // Longer input: two full 64-row groups plus 32 extra rows.
    let len = 2 * 64 + 32;

    let array = build_union(
        Int32Array::new_null(len),
        Float64Array::from_value(4.2, len),
        StringArray::new_null(len),
        [1, 1, 3, 3, 4, 4]
            .into_iter()
            .cycle()
            .take(len)
            .collect::<ScalarBuffer<i8>>(),
    );

    let expected = [false, false, true, true, false, false]
        .into_iter()
        .cycle()
        .take(len)
        .collect::<BooleanBuffer>();

    assert_eq!(array.len(), len);
    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(
        expected,
        array.mask_sparse_skip_fully_null(array.fields_logical_nulls())
    );
}
2162
/// Sparse union with 50 fields cycling through a fully valid, a partially null and a fully
/// null child: `logical_nulls` and `gather_nulls` must yield the same validity bits.
#[test]
fn test_sparse_union_logical_nulls_gather() {
    let n_fields = 50;

    let non_null = Int32Array::from_value(2, 4);
    let mixed = Int32Array::from(vec![None, None, Some(1), None]);
    let fully_null = Int32Array::new_null(4);

    // Odd type ids 1, 3, 5, ... paired with nullable Int32 fields named after their id.
    let fields: UnionFields = (1..)
        .step_by(2)
        .map(|type_id| {
            let field = Field::new(format!("f{type_id}"), DataType::Int32, true);
            (type_id, Arc::new(field))
        })
        .take(n_fields)
        .collect();

    // The three child patterns repeat until every field has a child.
    let children: Vec<ArrayRef> = [
        Arc::new(non_null) as ArrayRef,
        Arc::new(mixed),
        Arc::new(fully_null),
    ]
    .into_iter()
    .cycle()
    .take(n_fields)
    .collect();

    let array = UnionArray::try_new(fields, vec![1, 3, 3, 5].into(), None, children).unwrap();

    // Row 0 selects the valid child, rows 1-2 the mixed child, row 3 the null child.
    let expected = BooleanBuffer::from(vec![true, false, true, false]);

    assert_eq!(expected, array.logical_nulls().unwrap().into_inner());
    assert_eq!(expected, array.gather_nulls(array.fields_logical_nulls()));
}
2201
2202 fn union_fields() -> UnionFields {
2203 [
2204 (1, Arc::new(Field::new("A", DataType::Int32, true))),
2205 (3, Arc::new(Field::new("B", DataType::Float64, true))),
2206 (4, Arc::new(Field::new("C", DataType::Utf8, true))),
2207 ]
2208 .into_iter()
2209 .collect()
2210 }
2211
/// `is_nullable` must be false only when neither child of the union contains nulls.
#[test]
fn test_is_nullable() {
    // (int child has nulls, float child has nulls, expected `is_nullable`)
    let cases = [
        (false, false, false),
        (true, false, true),
        (false, true, true),
        (true, true, true),
    ];
    for (int_nullable, float_nullable, expected) in cases {
        let union = create_union_array(int_nullable, float_nullable);
        assert_eq!(union.is_nullable(), expected);
    }
}
2219
2220 fn create_union_array(int_nullable: bool, float_nullable: bool) -> UnionArray {
2227 let int_array = if int_nullable {
2228 Int32Array::from(vec![Some(1), None, Some(3)])
2229 } else {
2230 Int32Array::from(vec![1, 2, 3])
2231 };
2232 let float_array = if float_nullable {
2233 Float64Array::from(vec![Some(3.2), None, Some(4.2)])
2234 } else {
2235 Float64Array::from(vec![3.2, 4.2, 5.2])
2236 };
2237 let type_ids = [0, 1, 0].into_iter().collect::<ScalarBuffer<i8>>();
2238 let offsets = [0, 0, 0].into_iter().collect::<ScalarBuffer<i32>>();
2239 let union_fields = [
2240 (0, Arc::new(Field::new("A", DataType::Int32, true))),
2241 (1, Arc::new(Field::new("B", DataType::Float64, true))),
2242 ]
2243 .into_iter()
2244 .collect::<UnionFields>();
2245
2246 let children = vec![Arc::new(int_array) as Arc<dyn Array>, Arc::new(float_array)];
2247
2248 UnionArray::try_new(union_fields, type_ids, Some(offsets), children).unwrap()
2249 }
2250}