polars_compute/filter/
mod.rs1mod boolean;
3mod primitive;
4mod scalar;
5
6#[cfg(all(target_arch = "x86_64", feature = "simd"))]
7mod avx512;
8
9use arrow::array::builder::{ArrayBuilder, ShareStrategy, make_builder};
10use arrow::array::{
11 Array, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8ViewArray, new_empty_array,
12};
13use arrow::bitmap::Bitmap;
14use arrow::bitmap::utils::SlicesIterator;
15use arrow::with_match_primitive_type_full;
16pub use boolean::filter_boolean_kernel;
17
18pub fn filter(array: &dyn Array, mask: &BooleanArray) -> Box<dyn Array> {
19 assert_eq!(array.len(), mask.len());
20
21 if let Some(validities) = mask.validity() {
23 let combined_mask = mask.values() & validities;
24 filter_with_bitmap(array, &combined_mask)
25 } else {
26 filter_with_bitmap(array, mask.values())
27 }
28}
29
30pub fn filter_with_bitmap(array: &dyn Array, mask: &Bitmap) -> Box<dyn Array> {
31 let mut mask = mask.clone();
38 let leading_zeros = mask.take_leading_zeros();
39 mask.take_trailing_zeros();
40 let array = array.sliced(leading_zeros, mask.len());
41
42 let mask = &mask;
43 let array = array.as_ref();
44
45 let false_count = mask.unset_bits();
47 if false_count == mask.len() {
48 return new_empty_array(array.dtype().clone());
49 }
50 if false_count == 0 {
51 return array.to_boxed();
52 }
53
54 use arrow::datatypes::PhysicalType::*;
55 match array.dtype().to_physical_type() {
56 Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
57 let array: &PrimitiveArray<$T> = array.as_any().downcast_ref().unwrap();
58 let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask);
59 Box::new(PrimitiveArray::from_vec(values).with_validity(validity))
60 }),
61 Boolean => {
62 let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
63 let (values, validity) =
64 boolean::filter_bitmap_and_validity(array.values(), array.validity(), mask);
65 BooleanArray::new(array.dtype().clone(), values, validity).boxed()
66 },
67 BinaryView => {
68 let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
69 let views = array.views();
70 let validity = array.validity();
71 let (views, validity) = primitive::filter_values_and_validity(views, validity, mask);
72 unsafe {
73 BinaryViewArray::new_unchecked_unknown_md(
74 array.dtype().clone(),
75 views.into(),
76 array.data_buffers().clone(),
77 validity,
78 Some(array.total_buffer_len()),
79 )
80 }
81 .boxed()
82 },
83 Utf8View => {
84 let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
85 let views = array.views();
86 let validity = array.validity();
87 let (views, validity) = primitive::filter_values_and_validity(views, validity, mask);
88 unsafe {
89 BinaryViewArray::new_unchecked_unknown_md(
90 arrow::datatypes::ArrowDataType::BinaryView,
91 views.into(),
92 array.data_buffers().clone(),
93 validity,
94 Some(array.total_buffer_len()),
95 )
96 .to_utf8view_unchecked()
97 }
98 .boxed()
99 },
100 _ => {
101 let iter = SlicesIterator::new(mask);
102 let mut mutable = make_builder(array.dtype());
103 mutable.reserve(iter.slots());
104 iter.for_each(|(start, len)| {
105 mutable.subslice_extend(array, start, len, ShareStrategy::Always)
106 });
107 mutable.freeze()
108 },
109 }
110}