1use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26 ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{
30 ArrowNativeType, BooleanBuffer, NullBuffer, OffsetBuffer, RunEndBuffer, ScalarBuffer, bit_util,
31};
32use arrow_buffer::{Buffer, MutableBuffer};
33use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
34use arrow_data::transform::MutableArrayData;
35use arrow_schema::*;
36
/// Selectivity cut-off used when picking a default iteration strategy.
///
/// `IterationStrategy::default_strategy` compares the fraction of selected
/// rows against this value: a fraction strictly greater than it selects
/// `IterationStrategy::SlicesIterator`, anything else selects
/// `IterationStrategy::IndexIterator`.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
44
45#[derive(Debug)]
57pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
58
59impl<'a> SlicesIterator<'a> {
60 pub fn new(filter: &'a BooleanArray) -> Self {
62 filter.values().into()
63 }
64}
65
66impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
67 fn from(filter: &'a BooleanBuffer) -> Self {
68 Self(filter.set_slices())
69 }
70}
71
72impl Iterator for SlicesIterator<'_> {
73 type Item = (usize, usize);
74
75 fn next(&mut self) -> Option<Self::Item> {
76 self.0.next()
77 }
78}
79
80struct IndexIterator<'a> {
85 remaining: usize,
86 iter: BitIndexIterator<'a>,
87}
88
89impl<'a> IndexIterator<'a> {
90 fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
91 assert_eq!(filter.null_count(), 0);
92 let iter = filter.values().set_indices();
93 Self { remaining, iter }
94 }
95
96 pub fn collect(mut self) -> Vec<usize> {
100 let len = self.remaining;
101 let mut result = Vec::with_capacity(len);
102 let ptr: *mut usize = result.as_mut_ptr();
103 for i in 0..len {
104 let next = self.iter.next();
107 debug_assert!(next.is_some(), "IndexIterator exhausted early");
108 unsafe {
109 *ptr.add(i) = next.unwrap_unchecked();
110 }
111 }
112 unsafe {
114 result.set_len(len);
115 }
116 result
117 }
118}
119
120impl Iterator for IndexIterator<'_> {
121 type Item = usize;
122
123 fn next(&mut self) -> Option<Self::Item> {
124 if self.remaining != 0 {
125 let next = self.iter.next().expect("IndexIterator exhausted early");
128 self.remaining -= 1;
129 return Some(next);
131 }
132 None
133 }
134
135 fn size_hint(&self) -> (usize, Option<usize>) {
136 (self.remaining, Some(self.remaining))
137 }
138}
139
140fn filter_count(filter: &BooleanArray) -> usize {
142 filter.values().count_set_bits()
143}
144
145pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
173 let nulls = filter.nulls().unwrap();
174 let mask = filter.values() & nulls.inner();
175 BooleanArray::new(mask, None)
176}
177
178pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
207 let mut filter_builder = FilterBuilder::new(predicate);
208
209 if FilterBuilder::is_optimize_beneficial(values.data_type()) {
210 filter_builder = filter_builder.optimize();
213 }
214
215 let predicate = filter_builder.build();
216
217 filter_array(values, &predicate)
218}
219
220pub fn filter_record_batch(
231 record_batch: &RecordBatch,
232 predicate: &BooleanArray,
233) -> Result<RecordBatch, ArrowError> {
234 let mut filter_builder = FilterBuilder::new(predicate);
235 let num_cols = record_batch.num_columns();
236 if num_cols > 1
237 || (num_cols > 0
238 && FilterBuilder::is_optimize_beneficial(
239 record_batch.schema_ref().field(0).data_type(),
240 ))
241 {
242 filter_builder = filter_builder.optimize();
245 }
246 let filter = filter_builder.build();
247
248 filter.filter_record_batch(record_batch)
249}
250
251#[derive(Debug)]
253pub struct FilterBuilder {
254 filter: BooleanArray,
255 count: usize,
256 strategy: IterationStrategy,
257}
258
259impl FilterBuilder {
260 pub fn new(filter: &BooleanArray) -> Self {
262 let filter = match filter.null_count() {
263 0 => filter.clone(),
264 _ => prep_null_mask_filter(filter),
265 };
266
267 let count = filter_count(&filter);
268 let strategy = IterationStrategy::default_strategy(filter.len(), count);
269
270 Self {
271 filter,
272 count,
273 strategy,
274 }
275 }
276
277 pub fn optimize(mut self) -> Self {
288 match self.strategy {
289 IterationStrategy::SlicesIterator => {
290 let slices = SlicesIterator::new(&self.filter).collect();
291 self.strategy = IterationStrategy::Slices(slices)
292 }
293 IterationStrategy::IndexIterator => {
294 let indices = IndexIterator::new(&self.filter, self.count).collect();
295 self.strategy = IterationStrategy::Indices(indices)
296 }
297 _ => {}
298 }
299 self
300 }
301
302 pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
307 match data_type {
308 DataType::Struct(fields) => {
309 fields.len() > 1
310 || fields.len() == 1
311 && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
312 }
313 DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
314 _ => false,
315 }
316 }
317
318 pub fn build(self) -> FilterPredicate {
320 FilterPredicate {
321 filter: self.filter,
322 count: self.count,
323 strategy: self.strategy,
324 }
325 }
326}
327
328#[derive(Debug)]
330enum IterationStrategy {
331 SlicesIterator,
333 IndexIterator,
335 Indices(Vec<usize>),
337 Slices(Vec<(usize, usize)>),
339 All,
341 None,
343}
344
345impl IterationStrategy {
346 fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
349 if filter_length == 0 || filter_count == 0 {
350 return IterationStrategy::None;
351 }
352
353 if filter_count == filter_length {
354 return IterationStrategy::All;
355 }
356
357 let selectivity_frac = filter_count as f64 / filter_length as f64;
362 if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
363 return IterationStrategy::SlicesIterator;
364 }
365 IterationStrategy::IndexIterator
366 }
367}
368
369#[derive(Debug)]
371pub struct FilterPredicate {
372 filter: BooleanArray,
373 count: usize,
374 strategy: IterationStrategy,
375}
376
377impl FilterPredicate {
378 pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
380 filter_array(values, self)
381 }
382
383 pub fn filter_record_batch(
388 &self,
389 record_batch: &RecordBatch,
390 ) -> Result<RecordBatch, ArrowError> {
391 let filtered_arrays = record_batch
392 .columns()
393 .iter()
394 .map(|a| filter_array(a, self))
395 .collect::<Result<Vec<_>, _>>()?;
396
397 unsafe {
400 Ok(RecordBatch::new_unchecked(
401 record_batch.schema(),
402 filtered_arrays,
403 self.count,
404 ))
405 }
406 }
407
408 pub fn count(&self) -> usize {
410 self.count
411 }
412
413 pub fn filter_nulls(&self, nulls: Option<&NullBuffer>) -> Option<NullBuffer> {
420 let (null_count, nulls) = filter_null_mask(nulls, self)?;
421 let buffer = BooleanBuffer::new(nulls, 0, self.count);
422
423 debug_assert_eq!(null_count, buffer.len() - buffer.count_set_bits());
424 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
427 }
428}
429
430fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
431 if predicate.filter.len() > values.len() {
432 return Err(ArrowError::InvalidArgumentError(format!(
433 "Filter predicate of length {} is larger than target array of length {}",
434 predicate.filter.len(),
435 values.len()
436 )));
437 }
438
439 match predicate.strategy {
440 IterationStrategy::None => Ok(new_empty_array(values.data_type())),
441 IterationStrategy::All => Ok(values.slice(0, predicate.count)),
442 _ => downcast_primitive_array! {
444 values => Ok(Arc::new(filter_primitive(values, predicate))),
445 DataType::Boolean => {
446 let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
447 Ok(Arc::new(filter_boolean(values, predicate)))
448 }
449 DataType::Utf8 => {
450 Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
451 }
452 DataType::LargeUtf8 => {
453 Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
454 }
455 DataType::Utf8View => {
456 Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
457 }
458 DataType::Binary => {
459 Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
460 }
461 DataType::LargeBinary => {
462 Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
463 }
464 DataType::BinaryView => {
465 Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
466 }
467 DataType::FixedSizeBinary(_) => {
468 Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
469 }
470 DataType::ListView(_) => {
471 Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
472 }
473 DataType::LargeListView(_) => {
474 Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
475 }
476 DataType::RunEndEncoded(_, _) => {
477 downcast_run_array!{
478 values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
479 t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
480 }
481 }
482 DataType::Dictionary(_, _) => downcast_dictionary_array! {
483 values => Ok(Arc::new(filter_dict(values, predicate))),
484 t => unimplemented!("Filter not supported for dictionary type {:?}", t)
485 }
486 DataType::Struct(_) => {
487 Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
488 }
489 DataType::Union(_, UnionMode::Sparse) => {
490 Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
491 }
492 _ => {
493 let data = values.to_data();
494 let mut mutable = MutableArrayData::new(
496 vec![&data],
497 false,
498 predicate.count,
499 );
500
501 match &predicate.strategy {
502 IterationStrategy::Slices(slices) => {
503 slices
504 .iter()
505 .for_each(|(start, end)| mutable.extend(0, *start, *end));
506 }
507 _ => {
508 let iter = SlicesIterator::new(&predicate.filter);
509 iter.for_each(|(start, end)| mutable.extend(0, start, end));
510 }
511 }
512
513 let data = mutable.freeze();
514 Ok(make_array(data))
515 }
516 },
517 }
518}
519
520fn filter_run_end_array<R: RunEndIndexType>(
522 array: &RunArray<R>,
523 predicate: &FilterPredicate,
524) -> Result<RunArray<R>, ArrowError>
525where
526 R::Native: Into<i64> + From<bool>,
527 R::Native: AddAssign,
528{
529 let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
530 let start_physical = run_ends.get_start_physical_index();
531 let end_physical = run_ends.get_end_physical_index();
532 let physical_len = end_physical - start_physical + 1;
533
534 let mut new_run_ends = vec![R::default_value(); physical_len];
535 let offset = run_ends.offset() as u64;
536
537 let mut start = 0u64;
538 let mut j = 0;
539 let mut count = R::default_value();
540 let filter_values = predicate.filter.values();
541 let run_ends = run_ends.inner();
542
543 let pred: BooleanArray = BooleanBuffer::collect_bool(physical_len, |i| {
544 let mut keep = false;
545 let mut end = (run_ends[i + start_physical].into() as u64).saturating_sub(offset);
546 let difference = end.saturating_sub(filter_values.len() as u64);
547 end -= difference;
548
549 for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
551 count += R::Native::from(pred);
552 keep |= pred
553 }
554 new_run_ends[j] = count;
556 j += keep as usize;
557
558 start = end;
559 keep
560 })
561 .into();
562
563 new_run_ends.truncate(j);
564
565 let values = array.values_slice();
566 let values = filter(values.as_ref(), &pred)?;
567
568 let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
569 RunArray::try_new(&run_ends, &values)
570}
571
572fn filter_null_mask(
579 nulls: Option<&NullBuffer>,
580 predicate: &FilterPredicate,
581) -> Option<(usize, Buffer)> {
582 let nulls = nulls?;
583 if nulls.null_count() == 0 {
584 return None;
585 }
586
587 let nulls = filter_bits(nulls.inner(), predicate);
588 let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
591
592 if null_count == 0 {
593 return None;
594 }
595
596 Some((null_count, nulls))
597}
598
599fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
601 let src = buffer.values();
602 let offset = buffer.offset();
603 assert!(buffer.len() >= predicate.filter.len());
604
605 match &predicate.strategy {
606 IterationStrategy::IndexIterator => {
607 let bits =
608 IndexIterator::new(&predicate.filter, predicate.count).map(|src_idx| unsafe {
610 bit_util::get_bit_raw(buffer.values().as_ptr(), src_idx + offset)
611 });
612
613 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
615 }
616 IterationStrategy::Indices(indices) => {
617 let bits = indices.iter().map(|src_idx| unsafe {
619 bit_util::get_bit_raw(buffer.values().as_ptr(), *src_idx + offset)
620 });
621 unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
623 }
624 IterationStrategy::SlicesIterator => {
625 let mut builder = BooleanBufferBuilder::new(predicate.count);
626 for (start, end) in SlicesIterator::new(&predicate.filter) {
627 builder.append_packed_range(start + offset..end + offset, src)
628 }
629 builder.into()
630 }
631 IterationStrategy::Slices(slices) => {
632 let mut builder = BooleanBufferBuilder::new(predicate.count);
633 for (start, end) in slices {
634 builder.append_packed_range(*start + offset..*end + offset, src)
635 }
636 builder.into()
637 }
638 IterationStrategy::All | IterationStrategy::None => unreachable!(),
639 }
640}
641
642fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
644 let buffer = filter_bits(array.values(), predicate);
645 let values = BooleanBuffer::new(buffer, 0, predicate.count);
646 let nulls = predicate.filter_nulls(array.nulls());
647
648 BooleanArray::new(values, nulls)
649}
650
651#[inline(never)]
652fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
653 assert!(values.len() >= predicate.filter.len());
654
655 match &predicate.strategy {
656 IterationStrategy::SlicesIterator => {
657 let mut buffer = Vec::with_capacity(predicate.count);
658 for (start, end) in SlicesIterator::new(&predicate.filter) {
659 buffer.extend_from_slice(unsafe { values.get_unchecked(start..end) });
661 }
662 buffer.into()
663 }
664 IterationStrategy::Slices(slices) => {
665 let mut buffer = Vec::with_capacity(predicate.count);
666 for (start, end) in slices {
667 buffer.extend_from_slice(unsafe { values.get_unchecked(*start..*end) });
669 }
670 buffer.into()
671 }
672 IterationStrategy::IndexIterator => {
673 let iter = IndexIterator::new(&predicate.filter, predicate.count)
675 .map(|x| unsafe { *values.get_unchecked(x) });
676
677 unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
679 }
680 IterationStrategy::Indices(indices) => {
681 let iter = indices.iter().map(|x| unsafe { *values.get_unchecked(*x) });
683 iter.collect::<Vec<_>>().into()
684 }
685 IterationStrategy::All | IterationStrategy::None => unreachable!(),
686 }
687}
688
689fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
691where
692 T: ArrowPrimitiveType,
693{
694 let buffer = filter_native(array.values(), predicate);
695 let values = ScalarBuffer::new(buffer, 0, predicate.count);
696 let nulls = predicate.filter_nulls(array.nulls());
697 let filtered = PrimitiveArray::new(values, nulls);
698
699 if array.data_type() == &T::DATA_TYPE {
701 filtered
702 } else {
703 filtered.with_data_type(array.data_type().clone())
704 }
705}
706
707struct FilterBytes<'a, OffsetSize> {
712 src_offsets: &'a [OffsetSize],
713 src_values: &'a [u8],
714 dst_offsets: Vec<OffsetSize>,
715 dst_values: Vec<u8>,
716 cur_offset: OffsetSize,
717}
718
719impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
720where
721 OffsetSize: OffsetSizeTrait,
722{
723 fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
724 where
725 T: ByteArrayType<Offset = OffsetSize>,
726 {
727 let dst_values = Vec::new();
728 let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
729 let cur_offset = OffsetSize::from_usize(0).unwrap();
730
731 dst_offsets.push(cur_offset);
732
733 Self {
734 src_offsets: array.value_offsets(),
735 src_values: array.value_data(),
736 dst_offsets,
737 dst_values,
738 cur_offset,
739 }
740 }
741
742 #[inline]
744 fn get_value_offset(&self, idx: usize) -> usize {
745 self.src_offsets[idx].as_usize()
746 }
747
748 #[inline]
750 fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
751 let start = self.get_value_offset(idx);
753 let end = self.get_value_offset(idx + 1);
754 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
755 (start, end, len)
756 }
757
758 fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
759 self.dst_offsets.extend(iter.map(|idx| {
760 let start = self.src_offsets[idx].as_usize();
761 let end = self.src_offsets[idx + 1].as_usize();
762 let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
763 self.cur_offset += len;
764
765 self.cur_offset
766 }));
767 }
768
769 fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
771 self.dst_values.reserve_exact(self.cur_offset.as_usize());
772
773 for idx in iter {
774 let start = self.src_offsets[idx].as_usize();
775 let end = self.src_offsets[idx + 1].as_usize();
776 self.dst_values
777 .extend_from_slice(&self.src_values[start..end]);
778 }
779 }
780
781 fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
782 self.dst_offsets.reserve_exact(count);
783 for (start, end) in iter {
784 for idx in start..end {
786 let (_, _, len) = self.get_value_range(idx);
787 self.cur_offset += len;
788 self.dst_offsets.push(self.cur_offset);
789 }
790 }
791 }
792
793 fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
795 self.dst_values.reserve_exact(self.cur_offset.as_usize());
796
797 for (start, end) in iter {
798 let value_start = self.get_value_offset(start);
799 let value_end = self.get_value_offset(end);
800 self.dst_values
801 .extend_from_slice(&self.src_values[value_start..value_end]);
802 }
803 }
804}
805
806fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
811where
812 T: ByteArrayType,
813{
814 let mut filter = FilterBytes::new(predicate.count, array);
815
816 match &predicate.strategy {
817 IterationStrategy::SlicesIterator => {
818 filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
819 filter.extend_slices(SlicesIterator::new(&predicate.filter))
820 }
821 IterationStrategy::Slices(slices) => {
822 filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
823 filter.extend_slices(slices.iter().cloned())
824 }
825 IterationStrategy::IndexIterator => {
826 filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
827 filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
828 }
829 IterationStrategy::Indices(indices) => {
830 filter.extend_offsets_idx(indices.iter().cloned());
831 filter.extend_idx(indices.iter().cloned())
832 }
833 IterationStrategy::All | IterationStrategy::None => unreachable!(),
834 }
835
836 let offsets = unsafe { OffsetBuffer::new_unchecked(filter.dst_offsets.into()) };
839 let nulls = predicate.filter_nulls(array.nulls());
840
841 unsafe { GenericByteArray::new_unchecked(offsets, filter.dst_values.into(), nulls) }
845}
846
847fn filter_byte_view<T: ByteViewType>(
849 array: &GenericByteViewArray<T>,
850 predicate: &FilterPredicate,
851) -> GenericByteViewArray<T> {
852 let new_view_buffer = filter_native(array.views(), predicate);
853 let views = ScalarBuffer::new(new_view_buffer, 0, predicate.count);
854 let buffers = array.data_buffers().to_vec();
855 let nulls = predicate.filter_nulls(array.nulls());
856
857 unsafe { GenericByteViewArray::new_unchecked(views, buffers, nulls) }
861}
862
863fn filter_fixed_size_binary(
864 array: &FixedSizeBinaryArray,
865 predicate: &FilterPredicate,
866) -> FixedSizeBinaryArray {
867 let values: &[u8] = array.values();
868 let value_length = array.value_length() as usize;
869 let calculate_offset_from_index = |index: usize| index * value_length;
870 let buffer = match &predicate.strategy {
871 IterationStrategy::SlicesIterator => {
872 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
873 for (start, end) in SlicesIterator::new(&predicate.filter) {
874 buffer.extend_from_slice(
875 &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
876 );
877 }
878 buffer
879 }
880 IterationStrategy::Slices(slices) => {
881 let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
882 for (start, end) in slices {
883 buffer.extend_from_slice(
884 &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
885 );
886 }
887 buffer
888 }
889 IterationStrategy::IndexIterator => {
890 let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
891 &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
892 });
893
894 let mut buffer = MutableBuffer::new(predicate.count * value_length);
895 iter.for_each(|item| buffer.extend_from_slice(item));
896 buffer
897 }
898 IterationStrategy::Indices(indices) => {
899 let iter = indices.iter().map(|x| {
900 &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
901 });
902
903 let mut buffer = MutableBuffer::new(predicate.count * value_length);
904 iter.for_each(|item| buffer.extend_from_slice(item));
905 buffer
906 }
907 IterationStrategy::All | IterationStrategy::None => unreachable!(),
908 };
909
910 let nulls = predicate.filter_nulls(array.nulls());
911
912 FixedSizeBinaryArray::new(array.value_length(), buffer.into(), nulls)
913}
914
915fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
917where
918 T: ArrowDictionaryKeyType,
919 T::Native: num_traits::Num,
920{
921 let builder = filter_primitive::<T>(array.keys(), predicate)
922 .into_data()
923 .into_builder()
924 .data_type(array.data_type().clone())
925 .child_data(vec![array.values().to_data()]);
926
927 DictionaryArray::from(unsafe { builder.build_unchecked() })
930}
931
932fn filter_struct(
934 array: &StructArray,
935 predicate: &FilterPredicate,
936) -> Result<StructArray, ArrowError> {
937 let columns = array
938 .columns()
939 .iter()
940 .map(|column| filter_array(column, predicate))
941 .collect::<Result<_, _>>()?;
942
943 let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
944 let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
945
946 Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
947 } else {
948 None
949 };
950
951 Ok(unsafe {
952 StructArray::new_unchecked_with_length(
953 array.fields().clone(),
954 columns,
955 nulls,
956 predicate.count(),
957 )
958 })
959}
960
961fn filter_sparse_union(
963 array: &UnionArray,
964 predicate: &FilterPredicate,
965) -> Result<UnionArray, ArrowError> {
966 let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
967 unreachable!()
968 };
969
970 let type_ids = filter_primitive(
971 &Int8Array::try_new(array.type_ids().clone(), None)?,
972 predicate,
973 );
974
975 let children = fields
976 .iter()
977 .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
978 .collect::<Result<_, _>>()?;
979
980 Ok(unsafe {
981 UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
982 })
983}
984
985fn filter_list_view<OffsetType: OffsetSizeTrait>(
987 array: &GenericListViewArray<OffsetType>,
988 predicate: &FilterPredicate,
989) -> GenericListViewArray<OffsetType> {
990 let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
991 let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
992
993 let field = match array.data_type() {
994 DataType::ListView(field) | DataType::LargeListView(field) => field.clone(),
995 _ => unreachable!(),
996 };
997 let offsets = ScalarBuffer::new(filtered_offsets, 0, predicate.count);
998 let sizes = ScalarBuffer::new(filtered_sizes, 0, predicate.count);
999 let values = array.values().clone();
1000 let nulls = predicate.filter_nulls(array.nulls());
1001
1002 unsafe { GenericListViewArray::new_unchecked(field, offsets, sizes, values, nulls) }
1006}
1007
1008#[cfg(test)]
1009mod tests {
1010 use super::*;
1011 use arrow_array::builder::*;
1012 use arrow_array::cast::as_run_array;
1013 use arrow_array::types::*;
1014 use rand::distr::uniform::{UniformSampler, UniformUsize};
1015 use rand::distr::{Alphanumeric, StandardUniform};
1016 use rand::prelude::*;
1017 use rand::rng;
1018
1019 macro_rules! def_temporal_test {
1020 ($test:ident, $array_type: ident, $data: expr) => {
1021 #[test]
1022 fn $test() {
1023 let a = $data;
1024 let b = BooleanArray::from(vec![true, false, true, false]);
1025 let c = filter(&a, &b).unwrap();
1026 let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
1027 assert_eq!(2, d.len());
1028 assert_eq!(1, d.value(0));
1029 assert_eq!(3, d.value(1));
1030 }
1031 };
1032 }
1033
1034 def_temporal_test!(
1035 test_filter_date32,
1036 Date32Array,
1037 Date32Array::from(vec![1, 2, 3, 4])
1038 );
1039 def_temporal_test!(
1040 test_filter_date64,
1041 Date64Array,
1042 Date64Array::from(vec![1, 2, 3, 4])
1043 );
1044 def_temporal_test!(
1045 test_filter_time32_second,
1046 Time32SecondArray,
1047 Time32SecondArray::from(vec![1, 2, 3, 4])
1048 );
1049 def_temporal_test!(
1050 test_filter_time32_millisecond,
1051 Time32MillisecondArray,
1052 Time32MillisecondArray::from(vec![1, 2, 3, 4])
1053 );
1054 def_temporal_test!(
1055 test_filter_time64_microsecond,
1056 Time64MicrosecondArray,
1057 Time64MicrosecondArray::from(vec![1, 2, 3, 4])
1058 );
1059 def_temporal_test!(
1060 test_filter_time64_nanosecond,
1061 Time64NanosecondArray,
1062 Time64NanosecondArray::from(vec![1, 2, 3, 4])
1063 );
1064 def_temporal_test!(
1065 test_filter_duration_second,
1066 DurationSecondArray,
1067 DurationSecondArray::from(vec![1, 2, 3, 4])
1068 );
1069 def_temporal_test!(
1070 test_filter_duration_millisecond,
1071 DurationMillisecondArray,
1072 DurationMillisecondArray::from(vec![1, 2, 3, 4])
1073 );
1074 def_temporal_test!(
1075 test_filter_duration_microsecond,
1076 DurationMicrosecondArray,
1077 DurationMicrosecondArray::from(vec![1, 2, 3, 4])
1078 );
1079 def_temporal_test!(
1080 test_filter_duration_nanosecond,
1081 DurationNanosecondArray,
1082 DurationNanosecondArray::from(vec![1, 2, 3, 4])
1083 );
1084 def_temporal_test!(
1085 test_filter_timestamp_second,
1086 TimestampSecondArray,
1087 TimestampSecondArray::from(vec![1, 2, 3, 4])
1088 );
1089 def_temporal_test!(
1090 test_filter_timestamp_millisecond,
1091 TimestampMillisecondArray,
1092 TimestampMillisecondArray::from(vec![1, 2, 3, 4])
1093 );
1094 def_temporal_test!(
1095 test_filter_timestamp_microsecond,
1096 TimestampMicrosecondArray,
1097 TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
1098 );
1099 def_temporal_test!(
1100 test_filter_timestamp_nanosecond,
1101 TimestampNanosecondArray,
1102 TimestampNanosecondArray::from(vec![1, 2, 3, 4])
1103 );
1104
1105 #[test]
1106 fn test_filter_array_slice() {
1107 let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
1108 let b = BooleanArray::from(vec![true, false, false, true]);
1109 let c = filter(&a, &b).unwrap();
1113 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1114 assert_eq!(2, d.len());
1115 assert_eq!(6, d.value(0));
1116 assert_eq!(9, d.value(1));
1117 }
1118
1119 #[test]
1120 fn test_filter_array_low_density() {
1121 let mut data_values = (1..=65).collect::<Vec<i32>>();
1123 let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1124 data_values.extend_from_slice(&[66, 67]);
1126 filter_values.extend_from_slice(&[false, true]);
1127 let a = Int32Array::from(data_values);
1128 let b = BooleanArray::from(filter_values);
1129 let c = filter(&a, &b).unwrap();
1130 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1131 assert_eq!(2, d.len());
1132 assert_eq!(65, d.value(0));
1133 assert_eq!(67, d.value(1));
1134 }
1135
1136 #[test]
1137 fn test_filter_array_high_density() {
1138 let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1140 let mut filter_values = (1..=65)
1141 .map(|i| !matches!(i % 65, 0))
1142 .collect::<Vec<bool>>();
1143 data_values[1] = None;
1145 data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1147 filter_values.extend_from_slice(&[false, true, true, true]);
1148 let a = Int32Array::from(data_values);
1149 let b = BooleanArray::from(filter_values);
1150 let c = filter(&a, &b).unwrap();
1151 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1152 assert_eq!(67, d.len());
1153 assert_eq!(3, d.null_count());
1154 assert_eq!(1, d.value(0));
1155 assert!(d.is_null(1));
1156 assert_eq!(64, d.value(63));
1157 assert!(d.is_null(64));
1158 assert_eq!(67, d.value(65));
1159 }
1160
1161 #[test]
1162 fn test_filter_string_array_simple() {
1163 let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1164 let b = BooleanArray::from(vec![true, false, true, false]);
1165 let c = filter(&a, &b).unwrap();
1166 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1167 assert_eq!(2, d.len());
1168 assert_eq!("hello", d.value(0));
1169 assert_eq!("world", d.value(1));
1170 }
1171
1172 #[test]
1173 fn test_filter_primitive_array_with_null() {
1174 let a = Int32Array::from(vec![Some(5), None]);
1175 let b = BooleanArray::from(vec![false, true]);
1176 let c = filter(&a, &b).unwrap();
1177 let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1178 assert_eq!(1, d.len());
1179 assert!(d.is_null(0));
1180 }
1181
1182 #[test]
1183 fn test_filter_string_array_with_null() {
1184 let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1185 let b = BooleanArray::from(vec![true, false, false, true]);
1186 let c = filter(&a, &b).unwrap();
1187 let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1188 assert_eq!(2, d.len());
1189 assert_eq!("hello", d.value(0));
1190 assert!(!d.is_null(0));
1191 assert!(d.is_null(1));
1192 }
1193
1194 #[test]
1195 fn test_filter_binary_array_with_null() {
1196 let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1197 let a = BinaryArray::from(data);
1198 let b = BooleanArray::from(vec![true, false, false, true]);
1199 let c = filter(&a, &b).unwrap();
1200 let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1201 assert_eq!(2, d.len());
1202 assert_eq!(b"hello", d.value(0));
1203 assert!(!d.is_null(0));
1204 assert!(d.is_null(1));
1205 }
1206
1207 fn _test_filter_byte_view<T>()
1208 where
1209 T: ByteViewType,
1210 str: AsRef<T::Native>,
1211 T::Native: PartialEq,
1212 {
1213 let array = {
1214 let mut builder = GenericByteViewBuilder::<T>::new();
1216 builder.append_value("hello");
1217 builder.append_value("world");
1218 builder.append_null();
1219 builder.append_value("large payload over 12 bytes");
1220 builder.append_value("lulu");
1221 builder.finish()
1222 };
1223
1224 {
1225 let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1226 let actual = filter(&array, &predicate).unwrap();
1227
1228 assert_eq!(actual.len(), 3);
1229
1230 let expected = {
1231 let mut builder = GenericByteViewBuilder::<T>::new();
1233 builder.append_value("hello");
1234 builder.append_null();
1235 builder.append_value("large payload over 12 bytes");
1236 builder.finish()
1237 };
1238
1239 assert_eq!(actual.as_ref(), &expected);
1240 }
1241
1242 {
1243 let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1244 let actual = filter(&array, &predicate).unwrap();
1245
1246 assert_eq!(actual.len(), 2);
1247
1248 let expected = {
1249 let mut builder = GenericByteViewBuilder::<T>::new();
1251 builder.append_value("hello");
1252 builder.append_value("lulu");
1253 builder.finish()
1254 };
1255
1256 assert_eq!(actual.as_ref(), &expected);
1257 }
1258 }
1259
1260 #[test]
1261 fn test_filter_string_view() {
1262 _test_filter_byte_view::<StringViewType>()
1263 }
1264
/// Runs the shared byte-view filter checks for `BinaryViewType`.
#[test]
fn test_filter_binary_view() {
    _test_filter_byte_view::<BinaryViewType>();
}
1269
/// Filters a three element `FixedSizeBinaryArray` with several predicates and,
/// for two of them, checks that an optimized filter yields the same array.
#[test]
fn test_filter_fixed_binary() {
    let v1 = [1_u8, 2];
    let v2 = [3_u8, 4];
    let v3 = [5_u8, 6];
    let input = FixedSizeBinaryArray::try_from(vec![&v1, &v2, &v3]).unwrap();

    // Applies `mask` through a `FilterBuilder` with `optimize()` enabled.
    let via_optimized = |mask: &BooleanArray| {
        FilterBuilder::new(mask)
            .optimize()
            .build()
            .filter(&input)
            .unwrap()
    };

    // Keep the first and the last value.
    let mask = BooleanArray::from(vec![true, false, true]);
    let direct = filter(&input, &mask).unwrap();
    let direct = direct.as_fixed_size_binary();
    assert_eq!(direct.len(), 2);
    assert_eq!(direct.value(0), &v1);
    assert_eq!(direct.value(1), &v3);
    let optimized = via_optimized(&mask);
    assert_eq!(direct, optimized.as_fixed_size_binary());

    // Nothing selected.
    let mask = BooleanArray::from(vec![false, false, false]);
    let direct = filter(&input, &mask).unwrap();
    let direct = direct.as_fixed_size_binary();
    assert_eq!(direct.len(), 0);

    // Everything selected.
    let mask = BooleanArray::from(vec![true, true, true]);
    let direct = filter(&input, &mask).unwrap();
    let direct = direct.as_fixed_size_binary();
    assert_eq!(direct.len(), 3);
    assert_eq!(direct.value(0), &v1);
    assert_eq!(direct.value(1), &v2);
    assert_eq!(direct.value(2), &v3);

    // Only the last value selected.
    let mask = BooleanArray::from(vec![false, false, true]);
    let direct = filter(&input, &mask).unwrap();
    let direct = direct.as_fixed_size_binary();
    assert_eq!(direct.len(), 1);
    assert_eq!(direct.value(0), &v3);
    let optimized = via_optimized(&mask);
    assert_eq!(direct, optimized.as_fixed_size_binary());
}
1341
/// Filters a sliced `Int32Array` whose first visible slot is null.
#[test]
fn test_filter_array_slice_with_null() {
    let source = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]);
    // Drop the leading element so the input starts at offset 1: [null, 7, 8, 9].
    let sliced = source.slice(1, 4);
    let mask = BooleanArray::from(vec![true, false, false, true]);

    let filtered = filter(&sliced, &mask).unwrap();
    let ints = filtered.as_primitive::<Int32Type>();

    assert_eq!(2, ints.len());
    assert!(ints.is_null(0));
    assert!(!ints.is_null(1));
    assert_eq!(9, ints.value(1));
}
1356
/// Filters a run-end encoded array with an alternating predicate and checks the
/// resulting run ends and values.
#[test]
fn test_filter_run_end_encoding_array() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8]),
        &Int64Array::from(vec![7, -2, 9]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(4, actual.len());

    let expected_run_ends = Int64Array::from(vec![1, 2, 4]);
    let expected_values = Int64Array::from(vec![7, -2, 9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1376
/// Filters a sliced run-end encoded array and compares the logical values.
#[test]
fn test_filter_run_end_encoding_array_sliced() {
    let run_ends = Int64Array::from(vec![2, 3, 8]);
    let values = Int64Array::from(vec![7, -2, 9]);
    let full = RunArray::try_new(&run_ends, &values).unwrap();
    // Three logical elements starting at logical index 2.
    let sliced = full.slice(2, 3);
    let mask = BooleanArray::from(vec![true, false, true]);

    let filtered = filter(&sliced, &mask).unwrap();
    let run_array = filtered.as_run::<Int64Type>();
    let typed = run_array.downcast::<Int64Array>().unwrap();

    let actual = typed.into_iter().flatten().collect::<Vec<_>>();
    let expected = vec![-2, 9];
    assert_eq!(expected, actual);
}
1393
/// Filters a run-end encoded array so that two of its four runs contribute no
/// elements, and checks the remaining run ends and values.
#[test]
fn test_filter_run_end_encoding_array_remove_value() {
    let input = RunArray::try_new(
        &Int32Array::from(vec![2, 3, 8, 10]),
        &Int32Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![
        false, true, false, false, true, false, true, false, false, false,
    ]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int32Type> = as_run_array(&filtered);
    assert_eq!(3, actual.len());

    let expected_run_ends = Int32Array::from(vec![1, 3]);
    let expected_values = Int32Array::from(vec![7, 9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1413
/// Filters a run-end encoded array down to a single selected element.
#[test]
fn test_filter_run_end_encoding_array_remove_all_but_one() {
    let input = RunArray::try_new(
        &Int16Array::from(vec![2, 3, 8, 10]),
        &Int16Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    // Only logical index 6 is selected.
    let mask = BooleanArray::from(vec![
        false, false, false, false, false, false, true, false, false, false,
    ]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int16Type> = as_run_array(&filtered);
    assert_eq!(1, actual.len());

    let expected_run_ends = Int16Array::from(vec![1]);
    let expected_values = Int16Array::from(vec![9]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1432
/// An all-false predicate over a run-end encoded array yields an empty array.
#[test]
fn test_filter_run_end_encoding_array_empty() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![
        false, false, false, false, false, false, false, false, false, false,
    ]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(0, actual.len());
}
1445
/// Filters a ten element run-end encoded array with a predicate of only three
/// entries and checks the resulting runs.
#[test]
fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
    let input = RunArray::try_new(
        &Int64Array::from(vec![2, 3, 8, 10]),
        &Int64Array::from(vec![7, -2, 9, -8]),
    )
    .expect("Failed to create RunArray");
    let mask = BooleanArray::from(vec![false, true, true]);

    let filtered = filter(&input, &mask).unwrap();
    let actual: &RunArray<Int64Type> = as_run_array(&filtered);
    assert_eq!(2, actual.len());

    let expected_run_ends = Int64Array::from(vec![1, 2]);
    let expected_values = Int64Array::from(vec![7, -2]);
    let expected = RunArray::try_new(&expected_run_ends, &expected_values)
        .expect("Failed to make expected RunArray test is broken");

    assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
    assert_eq!(actual.values(), expected.values());
}
1465
/// Filters an `Int8DictionaryArray`: the dictionary values keep their length
/// while the keys are reduced to the selected slots.
#[test]
fn test_filter_dictionary_array() {
    let items = [Some("hello"), None, Some("world"), Some("!")];
    let input: Int8DictionaryArray = items.iter().copied().collect();
    // Select the null slot and "world".
    let mask = BooleanArray::from(vec![false, true, true, false]);

    let filtered = filter(&input, &mask).unwrap();
    let dict = filtered
        .as_ref()
        .as_any()
        .downcast_ref::<Int8DictionaryArray>()
        .unwrap();
    let dict_values = dict.values().as_string::<i32>();

    assert_eq!(3, dict_values.len());
    assert_eq!(2, dict.len());
    assert!(dict.is_null(0));
    assert_eq!("world", dict_values.value(dict.keys().value(1) as usize));
}
1486
/// Filters a `LargeListArray` with one null entry and compares against a list
/// array built from the expected offsets, values and validity.
#[test]
fn test_filter_list_array() {
    let item_field = Arc::new(Field::new_list_field(DataType::Int32, false));

    // Four lists over the values 0..8; the last list is empty and null.
    let input = LargeListArray::new(
        item_field.clone(),
        OffsetBuffer::new(vec![0i64, 3, 6, 8, 8].into()),
        Arc::new(Int32Array::from_iter_values(0..8)),
        Some(NullBuffer::from(vec![true, true, true, false])),
    );

    let mask = BooleanArray::from(vec![false, true, false, true]);
    let actual = filter(&input, &mask).unwrap();

    // The second list ([3, 4, 5]) followed by the null list.
    let expected: ArrayRef = Arc::new(LargeListArray::new(
        item_field,
        OffsetBuffer::new(vec![0i64, 3, 3].into()),
        Arc::new(Int32Array::from_iter_values([3, 4, 5])),
        Some(NullBuffer::from(vec![true, false])),
    ));

    assert_eq!(&expected, &actual);
}
1506
1507 fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1508 let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1510 list_array.append_value([Some(1), Some(2)]);
1511 list_array.append_null();
1512 list_array.append_value([]);
1513 list_array.append_value([Some(3), Some(4)]);
1514
1515 let list_array = list_array.finish();
1516 let predicate = BooleanArray::from_iter([true, false, true, false]);
1517
1518 let filtered = filter(&list_array, &predicate)
1520 .unwrap()
1521 .as_list_view::<T>()
1522 .clone();
1523
1524 let mut expected =
1525 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1526 expected.append_value([Some(1), Some(2)]);
1527 expected.append_value([]);
1528 let expected = expected.finish();
1529
1530 assert_eq!(&filtered, &expected);
1531 }
1532
1533 fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1534 let mut list_array =
1536 GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1537 list_array.append_value([Some(1), Some(2)]);
1538 list_array.append_null();
1539 list_array.append_value([]);
1540 list_array.append_value([Some(3), Some(4)]);
1541
1542 let list_array = list_array.finish();
1543
1544 let sliced = list_array.slice(1, 3);
1546 let predicate = BooleanArray::from_iter([false, false, true]);
1547
1548 let filtered = filter(&sliced, &predicate)
1550 .unwrap()
1551 .as_list_view::<T>()
1552 .clone();
1553
1554 let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1555 expected.append_value([Some(3), Some(4)]);
1556 let expected = expected.finish();
1557
1558 assert_eq!(&filtered, &expected);
1559 }
1560
/// Runs the list-view filter cases for both `i32` and `i64` offsets.
#[test]
fn test_filter_list_view_array() {
    // Unsliced inputs.
    test_case_filter_list_view::<i32>();
    test_case_filter_list_view::<i64>();

    // Sliced inputs.
    test_case_filter_sliced_list_view::<i32>();
    test_case_filter_sliced_list_view::<i64>();
}
1569
/// A 64 bit mask with only bit 1 set produces a single slice `(1, 2)`.
#[test]
fn test_slice_iterator_bits() {
    let mut bits = vec![false; 64];
    bits[1] = true;
    let mask = BooleanArray::from(bits);
    let selected = filter_count(&mask);

    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(1, 2)]);
    assert_eq!(selected, 1);
}
1582
/// A 64 bit mask with only bit 1 cleared produces two slices around the gap.
#[test]
fn test_slice_iterator_bits1() {
    let mut bits = vec![true; 64];
    bits[1] = false;
    let mask = BooleanArray::from(bits);
    let selected = filter_count(&mask);

    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(0, 1), (2, 64)]);
    assert_eq!(selected, 64 - 1);
}
1595
/// A 130 bit mask cleared at every multiple of 62 produces three slices.
#[test]
fn test_slice_iterator_chunk_and_bits() {
    let bits: Vec<bool> = (0..130).map(|position| position % 62 != 0).collect();
    let mask = BooleanArray::from(bits);
    let selected = filter_count(&mask);

    let slices: Vec<_> = SlicesIterator::new(&mask).collect();

    assert_eq!(slices, vec![(1, 62), (63, 124), (125, 130)]);
    assert_eq!(selected, 61 + 61 + 5);
}
1608
/// A null entry in the predicate does not select the corresponding element.
#[test]
fn test_null_mask() {
    let input = Int64Array::from(vec![Some(1), Some(2), None]);
    let mask = BooleanArray::from(vec![Some(true), Some(true), None]);

    let filtered = filter(&input, &mask).unwrap();

    let expected = input.slice(0, 2);
    assert_eq!(filtered.as_ref(), &expected);
}
1617
/// Filtering a record batch that has no columns still reports the number of
/// rows selected by the predicate.
#[test]
fn test_filter_record_batch_no_columns() {
    let options = RecordBatchOptions::default().with_row_count(Some(100));
    let empty_schema = Arc::new(Schema::empty());
    let batch = RecordBatch::try_new_with_options(empty_schema, vec![], &options).unwrap();

    // Two true entries followed by a null entry.
    let mask = BooleanArray::from(vec![Some(true), Some(true), None]);
    let filtered = filter_record_batch(&batch, &mask).unwrap();

    assert_eq!(filtered.num_rows(), 2);
}
1628
/// Checks the all-true and all-false predicates: the former returns an array
/// equal to the input, the latter an empty array of the same data type.
#[test]
fn test_fast_path() {
    let input: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);

    // Every element selected.
    let select_all = BooleanArray::from(vec![true, true, true]);
    let kept = filter(&input, &select_all).unwrap();
    assert_eq!(&input, kept.as_primitive::<Int64Type>());

    // No element selected.
    let select_none = BooleanArray::from(vec![false, false, false]);
    let dropped = filter(&input, &select_none).unwrap();
    assert_eq!(dropped.len(), 0);
    assert_eq!(dropped.data_type(), &DataType::Int64);
}
1648
/// Checks the slices reported by `SlicesIterator` for a mask made of five runs,
/// both unsliced and after slicing the mask.
#[test]
fn test_slices() {
    // 10 set, 30 unset, 20 set, 17 unset, 4 set.
    let mask: BooleanArray = std::iter::repeat_n(Some(true), 10)
        .chain(std::iter::repeat_n(Some(false), 30))
        .chain(std::iter::repeat_n(Some(true), 20))
        .chain(std::iter::repeat_n(Some(false), 17))
        .chain(std::iter::repeat_n(Some(true), 4))
        .collect();

    let actual: Vec<_> = SlicesIterator::new(&mask).collect();
    assert_eq!(actual, vec![(0, 10), (40, 60), (77, 81)]);

    // Skip the first 7 bits and shorten the mask by 10 bits in total.
    let total = mask.len();
    let shortened = mask.slice(7, total - 10);
    let shortened = shortened
        .as_any()
        .downcast_ref::<BooleanArray>()
        .unwrap();

    let actual: Vec<_> = SlicesIterator::new(shortened).collect();
    assert_eq!(actual, vec![(0, 3), (33, 53), (70, 71)]);
}
1675
1676 fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1677 let mut rng = rng();
1678
1679 let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1680 .take(mask_len)
1681 .collect();
1682
1683 let buffer = Buffer::from_iter(bools.iter().cloned());
1684
1685 let truncated_length = mask_len - offset - truncate;
1686
1687 let filter = BooleanArray::new(BooleanBuffer::new(buffer, offset, truncated_length), None);
1688
1689 let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1690 .flat_map(|(start, end)| start..end)
1691 .collect();
1692
1693 let count = filter_count(&filter);
1694 let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1695
1696 let expected_bits: Vec<_> = bools
1697 .iter()
1698 .skip(offset)
1699 .take(truncated_length)
1700 .enumerate()
1701 .flat_map(|(idx, v)| v.then(|| idx))
1702 .collect();
1703
1704 assert_eq!(slice_bits, expected_bits);
1705 assert_eq!(index_bits, expected_bits);
1706 }
1707
/// Drives `test_slices_fuzz` with 100 random (length, offset, truncate) triples
/// followed by a handful of fixed combinations.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_test_slices_iterator() {
    let mut generator = rng();
    let any_usize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();

    for _ in 0..100 {
        let mask_len = generator.random_range(0..1024);

        // `checked_rem` yields `None` when the bound is zero; fall back to 0 then.
        let max_offset = 64.min(mask_len);
        let offset = any_usize
            .sample(&mut generator)
            .checked_rem(max_offset)
            .unwrap_or(0);

        let max_truncate = 128.min(mask_len - offset);
        let truncate = any_usize
            .sample(&mut generator)
            .checked_rem(max_truncate)
            .unwrap_or(0);

        test_slices_fuzz(mask_len, offset, truncate);
    }

    // Fixed (mask_len, offset, truncate) combinations.
    for (mask_len, offset, truncate) in [(64, 0, 0), (64, 8, 0), (64, 8, 8), (32, 8, 8), (32, 5, 9)]
    {
        test_slices_fuzz(mask_len, offset, truncate);
    }
}
1734
/// Reference filter: pairs `values` with `predicate` positionally and returns the
/// values whose predicate entry is `true`. Pairing stops at the shorter input.
fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
    values
        .into_iter()
        .zip(predicate.iter().copied())
        .filter_map(|(value, keep)| keep.then_some(value))
        .collect()
}
1744
1745 fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1747 where
1748 StandardUniform: Distribution<T>,
1749 {
1750 let mut rng = rng();
1751 (0..len)
1752 .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1753 .collect()
1754 }
1755
1756 fn gen_strings(
1758 len: usize,
1759 valid_percent: f64,
1760 str_len_range: std::ops::Range<usize>,
1761 ) -> Vec<Option<String>> {
1762 let mut rng = rng();
1763 (0..len)
1764 .map(|_| {
1765 rng.random_bool(valid_percent).then(|| {
1766 let len = rng.random_range(str_len_range.clone());
1767 (0..len)
1768 .map(|_| char::from(rng.sample(Alphanumeric)))
1769 .collect()
1770 })
1771 })
1772 .collect()
1773 }
1774
/// Iterates over `src`, turning each `Option<T>` into an `Option<&T::Target>`
/// (for example `Option<String>` into `Option<&str>`).
fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
    src.iter().map(Option::as_deref)
}
1779
/// Fuzzes `filter` against the `filter_rust` reference for sliced primitive,
/// string and dictionary arrays combined with a sliced, possibly shortened
/// predicate.
#[test]
#[cfg_attr(miri, ignore)]
fn fuzz_filter() {
    let mut generator = rng();

    for round in 0..100 {
        // The first rounds pin the selectivity to all / none; later rounds randomize it.
        let filter_percent = match round {
            0..=4 => 1.,
            5..=10 => 0.,
            _ => generator.random_range(0.0..1.0),
        };
        let valid_percent = generator.random_range(0.0..1.0);

        let array_len = generator.random_range(32..256);
        let array_offset = generator.random_range(0..10);

        let filter_offset = generator.random_range(0..10);
        let filter_truncate = generator.random_range(0..10);

        // Predicate bits, including the `filter_offset` leading bits that get sliced away.
        let all_bools: Vec<_> =
            std::iter::from_fn(|| Some(generator.random_bool(filter_percent)))
                .take(array_len + filter_offset - filter_truncate)
                .collect();

        let full_predicate = BooleanArray::from_iter(all_bools.iter().cloned().map(Some));
        let predicate = full_predicate.slice(filter_offset, array_len - filter_truncate);
        let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
        let bools = &all_bools[filter_offset..];

        // Primitive values.
        let all_ints = gen_primitive(array_len + array_offset, valid_percent);
        let int_array = Int32Array::from_iter(all_ints.iter().cloned());
        let int_array = int_array.slice(array_offset, array_len);
        let int_array = int_array.as_any().downcast_ref::<Int32Array>().unwrap();
        let ints = &all_ints[array_offset..];

        let filtered = filter(int_array, predicate).unwrap();
        let actual_ints: Vec<_> = filtered.as_primitive::<Int32Type>().iter().collect();
        assert_eq!(actual_ints, filter_rust(ints.iter().cloned(), bools));

        // String values.
        let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
        let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);

        let string_array = StringArray::from_iter(as_deref(&strings));
        let string_array = string_array.slice(array_offset, array_len);
        let string_array = string_array
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        let filtered = filter(string_array, predicate).unwrap();
        let actual_strings: Vec<_> = filtered.as_string::<i32>().iter().collect();
        assert_eq!(actual_strings, expected_strings);

        // The same strings, dictionary encoded.
        let dict_array = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
        let dict_array = dict_array.slice(array_offset, array_len);
        let dict_array = dict_array
            .as_any()
            .downcast_ref::<DictionaryArray<Int32Type>>()
            .unwrap();

        let filtered = filter(dict_array, predicate).unwrap();
        let filtered_dict = filtered.as_dictionary::<Int32Type>();
        let dict_values = filtered_dict.values().as_string::<i32>();

        // Resolve every key through the dictionary values before comparing.
        let actual_dict: Vec<_> = filtered_dict
            .keys()
            .iter()
            .map(|key| key.map(|key| dict_values.value(key as usize)))
            .collect();
        assert_eq!(actual_dict, expected_strings);
    }
}
1870
/// Filters a four entry map array (one entry null) down to its first and last
/// entries and compares with a map array built from just those entries.
#[test]
fn test_filter_map() {
    let mut source =
        MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
    // Entry 0: {key1: 1}
    source.keys().append_value("key1");
    source.values().append_value(1);
    source.append(true).unwrap();
    // Entry 1: {key2: 2, key3: 3}
    source.keys().append_value("key2");
    source.values().append_value(2);
    source.keys().append_value("key3");
    source.values().append_value(3);
    source.append(true).unwrap();
    // Entry 2: null
    source.append(false).unwrap();
    // Entry 3: {key1: 1}
    source.keys().append_value("key1");
    source.values().append_value(1);
    source.append(true).unwrap();
    let input = Arc::new(source.finish()) as ArrayRef;

    let mask = BooleanArray::from(vec![Some(true), Some(false), Some(false), Some(true)]);
    let actual = filter(&input, &mask).unwrap();

    let mut wanted =
        MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
    for _ in 0..2 {
        wanted.keys().append_value("key1");
        wanted.values().append_value(1);
        wanted.append(true).unwrap();
    }
    let expected = Arc::new(wanted.finish()) as ArrayRef;

    assert_eq!(&expected, &actual);
}
1907
/// Filters a `FixedSizeListArray` of three lists of three values each.
#[test]
fn test_filter_fixed_size_list_arrays() {
    let item_field = Arc::new(Field::new_list_field(DataType::Int32, false));
    let items = Arc::new(Int32Array::from_iter_values(0..9));
    let input = FixedSizeListArray::new(item_field, 3, items, None);

    // Keep only the first list.
    let mask = BooleanArray::from(vec![true, false, false]);
    let out = filter(&input, &mask).unwrap();
    let lists = out.as_fixed_size_list();
    assert_eq!(lists.len(), 1);

    let first = lists.value(0);
    assert_eq!(&[0, 1, 2], first.as_primitive::<Int32Type>().values());

    // Keep the first and the last list.
    let mask = BooleanArray::from(vec![true, false, true]);
    let out = filter(&input, &mask).unwrap();
    let lists = out.as_fixed_size_list();
    assert_eq!(lists.len(), 2);

    let first = lists.value(0);
    assert_eq!(&[0, 1, 2], first.as_primitive::<Int32Type>().values());
    let second = lists.value(1);
    assert_eq!(&[6, 7, 8], second.as_primitive::<Int32Type>().values());
}
1945
/// Filters a `FixedSizeListArray` that contains null lists and checks that a
/// selected null list stays null in the output.
#[test]
fn test_filter_fixed_size_list_arrays_with_null() {
    let item_field = Arc::new(Field::new_list_field(DataType::Int32, false));
    let items = Arc::new(Int32Array::from_iter_values(0..10));
    // Lists 1 and 2 are null.
    let validity = Some(NullBuffer::from(vec![true, false, false, true, true]));
    let input = FixedSizeListArray::new(item_field, 2, items, validity);

    let mask = BooleanArray::from(vec![true, true, false, true, false]);
    let out = filter(&input, &mask).unwrap();
    let lists = out.as_fixed_size_list();
    assert_eq!(lists.len(), 3);

    let first = lists.value(0);
    assert_eq!(&[0, 1], first.as_primitive::<Int32Type>().values());
    assert!(lists.is_null(1));
    let last = lists.value(2);
    assert_eq!(&[6, 7], last.as_primitive::<Int32Type>().values());
}
1972
1973 fn test_filter_union_array(array: UnionArray) {
1974 let filter_array = BooleanArray::from(vec![true, false, false]);
1975 let c = filter(&array, &filter_array).unwrap();
1976 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1977
1978 let mut builder = UnionBuilder::new_dense();
1979 builder.append::<Int32Type>("A", 1).unwrap();
1980 let expected_array = builder.build().unwrap();
1981
1982 compare_union_arrays(filtered, &expected_array);
1983
1984 let filter_array = BooleanArray::from(vec![true, false, true]);
1985 let c = filter(&array, &filter_array).unwrap();
1986 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1987
1988 let mut builder = UnionBuilder::new_dense();
1989 builder.append::<Int32Type>("A", 1).unwrap();
1990 builder.append::<Int32Type>("A", 34).unwrap();
1991 let expected_array = builder.build().unwrap();
1992
1993 compare_union_arrays(filtered, &expected_array);
1994
1995 let filter_array = BooleanArray::from(vec![true, true, false]);
1996 let c = filter(&array, &filter_array).unwrap();
1997 let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1998
1999 let mut builder = UnionBuilder::new_dense();
2000 builder.append::<Int32Type>("A", 1).unwrap();
2001 builder.append::<Float64Type>("B", 3.2).unwrap();
2002 let expected_array = builder.build().unwrap();
2003
2004 compare_union_arrays(filtered, &expected_array);
2005 }
2006
/// Runs the shared union filter checks on a dense union.
#[test]
fn test_filter_union_array_dense() {
    let mut dense = UnionBuilder::new_dense();
    dense.append::<Int32Type>("A", 1).unwrap();
    dense.append::<Float64Type>("B", 3.2).unwrap();
    dense.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(dense.build().unwrap());
}
2017
/// Filters a dense union whose elements all share one type and compares the
/// result's array data with that of a union built from the expected values.
#[test]
fn test_filter_run_union_array_dense() {
    let mut source = UnionBuilder::new_dense();
    for value in [1, 3, 34] {
        source.append::<Int32Type>("A", value).unwrap();
    }
    let input = source.build().unwrap();

    let mask = BooleanArray::from(vec![true, true, false]);
    let out = filter(&input, &mask).unwrap();
    let actual = out.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    for value in [1, 3] {
        wanted.append::<Int32Type>("A", value).unwrap();
    }
    let expected = wanted.build().unwrap();

    assert_eq!(actual.to_data(), expected.to_data());
}
2037
/// Filters a dense union that contains a null slot, once skipping and once
/// selecting that slot.
#[test]
fn test_filter_union_array_dense_with_nulls() {
    let mut source = UnionBuilder::new_dense();
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null::<Float64Type>("B").unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let input = source.build().unwrap();

    // The null slot is not selected.
    let mask = BooleanArray::from(vec![true, true, false, false]);
    let out = filter(&input, &mask).unwrap();
    let actual = out.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append::<Float64Type>("B", 3.2).unwrap();
    let expected = wanted.build().unwrap();
    compare_union_arrays(actual, &expected);

    // The null slot is selected.
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let out = filter(&input, &mask).unwrap();
    let actual = out.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_dense();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null::<Float64Type>("B").unwrap();
    let expected = wanted.build().unwrap();
    compare_union_arrays(actual, &expected);
}
2069
/// Runs the shared union filter checks on a sparse union.
#[test]
fn test_filter_union_array_sparse() {
    let mut sparse = UnionBuilder::new_sparse();
    sparse.append::<Int32Type>("A", 1).unwrap();
    sparse.append::<Float64Type>("B", 3.2).unwrap();
    sparse.append::<Int32Type>("A", 34).unwrap();

    test_filter_union_array(sparse.build().unwrap());
}
2080
/// Filters a sparse union that contains a null slot, selecting that slot.
#[test]
fn test_filter_union_array_sparse_with_nulls() {
    let mut source = UnionBuilder::new_sparse();
    source.append::<Int32Type>("A", 1).unwrap();
    source.append::<Float64Type>("B", 3.2).unwrap();
    source.append_null::<Float64Type>("B").unwrap();
    source.append::<Int32Type>("A", 34).unwrap();
    let input = source.build().unwrap();

    // Select the first value and the null slot.
    let mask = BooleanArray::from(vec![true, false, true, false]);
    let out = filter(&input, &mask).unwrap();
    let actual = out.as_any().downcast_ref::<UnionArray>().unwrap();

    let mut wanted = UnionBuilder::new_sparse();
    wanted.append::<Int32Type>("A", 1).unwrap();
    wanted.append_null::<Float64Type>("B").unwrap();
    let expected = wanted.build().unwrap();

    compare_union_arrays(actual, &expected);
}
2101
2102 fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2103 assert_eq!(union1.len(), union2.len());
2104
2105 for i in 0..union1.len() {
2106 let type_id = union1.type_id(i);
2107
2108 let slot1 = union1.value(i);
2109 let slot2 = union2.value(i);
2110
2111 assert_eq!(slot1.is_null(0), slot2.is_null(0));
2112
2113 if !slot1.is_null(0) && !slot2.is_null(0) {
2114 match type_id {
2115 0 => {
2116 let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2117 assert_eq!(slot1.len(), 1);
2118 let value1 = slot1.value(0);
2119
2120 let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2121 assert_eq!(slot2.len(), 1);
2122 let value2 = slot2.value(0);
2123 assert_eq!(value1, value2);
2124 }
2125 1 => {
2126 let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2127 assert_eq!(slot1.len(), 1);
2128 let value1 = slot1.value(0);
2129
2130 let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2131 assert_eq!(slot2.len(), 1);
2132 let value2 = slot2.value(0);
2133 assert_eq!(value1, value2);
2134 }
2135 _ => unreachable!(),
2136 }
2137 }
2138 }
2139 }
2140
/// Filters struct arrays with one and two child columns, each with and without
/// a struct-level null buffer, and compares the resulting array data.
#[test]
fn test_filter_struct() {
    let predicate = BooleanArray::from(vec![true, false, true, false]);

    let a: ArrayRef = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
    let a_filtered: ArrayRef = Arc::new(StringArray::from(vec!["hello", "world"]));

    let b: ArrayRef = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
    let b_filtered: ArrayRef = Arc::new(Int32Array::from(vec![5, 7]));

    let null_mask = NullBuffer::from(vec![true, false, false, true]);
    let null_mask_filtered = NullBuffer::from(vec![true, false]);

    let a_field = Field::new("a", DataType::Utf8, false);
    let b_field = Field::new("b", DataType::Int32, false);

    // Builds the input and expected struct arrays for one case and checks the filter.
    let check = |fields: Vec<Field>,
                 columns: Vec<ArrayRef>,
                 columns_filtered: Vec<ArrayRef>,
                 with_nulls: bool| {
        let (nulls, nulls_filtered) = if with_nulls {
            (Some(null_mask.clone()), Some(null_mask_filtered.clone()))
        } else {
            (None, None)
        };
        let array = StructArray::new(fields.clone().into(), columns, nulls);
        let expected = StructArray::new(fields.into(), columns_filtered, nulls_filtered);

        let result = filter(&array, &predicate).unwrap();

        assert_eq!(result.to_data(), expected.to_data());
    };

    // One child column, no struct-level nulls.
    check(
        vec![a_field.clone()],
        vec![a.clone()],
        vec![a_filtered.clone()],
        false,
    );

    // One child column, with struct-level nulls.
    check(
        vec![a_field.clone()],
        vec![a.clone()],
        vec![a_filtered.clone()],
        true,
    );

    // Two child columns, no struct-level nulls.
    check(
        vec![a_field.clone(), b_field.clone()],
        vec![a.clone(), b.clone()],
        vec![a_filtered.clone(), b_filtered.clone()],
        false,
    );

    // Two child columns, with struct-level nulls.
    check(
        vec![a_field.clone(), b_field.clone()],
        vec![a.clone(), b.clone()],
        vec![a_filtered.clone(), b_filtered.clone()],
        true,
    );
}
2211
/// Filters a record batch whose single column is a struct containing an `Int64`
/// child and a nested struct with no fields at all.
///
/// The nested empty struct has no child arrays to carry a length, so the test
/// checks that the filtered batch still reports the number of selected rows.
#[test]
fn test_filter_empty_struct() {
    // Children of the top-level struct column "a": an Int64 and an empty struct.
    let inner_fields = Fields::from(vec![
        Field::new("b", DataType::Int64, true),
        Field::new("c", DataType::Struct(Fields::empty()), true),
    ]);

    let schema = Arc::new(Schema::new(vec![Field::new(
        "a",
        DataType::Struct(inner_fields.clone()),
        true,
    )]));

    let b: ArrayRef = Arc::new(Int64Array::from(vec![None, None, None]));
    let c: ArrayRef = Arc::new(StructArray::new_empty_fields(
        3,
        Some(NullBuffer::from(vec![true, true, true])),
    ));
    let a = StructArray::new(
        inner_fields,
        vec![b, c],
        Some(NullBuffer::from(vec![true, true, true])),
    );
    let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();

    // Select two of the three rows.
    let predicate = BooleanArray::from(vec![true, false, true]);
    let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();

    assert_eq!(filtered_batch.num_rows(), 2);
}
2266
/// `filter_bits` is expected to panic when the predicate (9 bits) is longer
/// than the buffer being filtered (8 bits).
#[test]
#[should_panic]
fn test_filter_bits_too_large() {
    let bits = BooleanBuffer::from(vec![false; 8]);
    let oversized = BooleanArray::from(vec![true; 9]);
    let predicate = FilterBuilder::new(&oversized).build();
    filter_bits(&bits, &predicate);
}
2275
/// `filter_native` is expected to panic when the predicate (9 entries) is
/// longer than the slice being filtered (8 values).
#[test]
#[should_panic]
fn test_filter_native_too_large() {
    let data = vec![1; 8];
    let oversized = BooleanArray::from(vec![false; 9]);
    let predicate = FilterBuilder::new(&oversized).build();
    filter_native(&data, &predicate);
}
2284}