polars_compute/filter/
mod.rs

1//! Contains operators to filter arrays such as [`filter`].
2mod boolean;
3mod primitive;
4mod scalar;
5
6#[cfg(all(target_arch = "x86_64", feature = "simd"))]
7mod avx512;
8
9use arrow::array::builder::{ArrayBuilder, ShareStrategy, make_builder};
10use arrow::array::{
11    Array, BinaryViewArray, BooleanArray, PrimitiveArray, Utf8ViewArray, new_empty_array,
12};
13use arrow::bitmap::Bitmap;
14use arrow::bitmap::utils::SlicesIterator;
15use arrow::with_match_primitive_type_full;
16pub use boolean::filter_boolean_kernel;
17
18pub fn filter(array: &dyn Array, mask: &BooleanArray) -> Box<dyn Array> {
19    assert_eq!(array.len(), mask.len());
20
21    // Treat null mask values as false.
22    if let Some(validities) = mask.validity() {
23        let combined_mask = mask.values() & validities;
24        filter_with_bitmap(array, &combined_mask)
25    } else {
26        filter_with_bitmap(array, mask.values())
27    }
28}
29
30pub fn filter_with_bitmap(array: &dyn Array, mask: &Bitmap) -> Box<dyn Array> {
31    // Many filters involve filtering values in a subsection of the array. When we trim the leading
32    // and trailing filtered items, we can close in on those items and not have to perform and
33    // thinking about those. The overhead for when there are no leading or trailing filtered values
34    // is very minimal: only a clone of the mask and the array.
35    //
36    // This also allows dispatching to the fast paths way, way, way more often.
37    let mut mask = mask.clone();
38    let leading_zeros = mask.take_leading_zeros();
39    mask.take_trailing_zeros();
40    let array = array.sliced(leading_zeros, mask.len());
41
42    let mask = &mask;
43    let array = array.as_ref();
44
45    // Fast-path: completely empty or completely full mask.
46    let false_count = mask.unset_bits();
47    if false_count == mask.len() {
48        return new_empty_array(array.dtype().clone());
49    }
50    if false_count == 0 {
51        return array.to_boxed();
52    }
53
54    use arrow::datatypes::PhysicalType::*;
55    match array.dtype().to_physical_type() {
56        Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
57            let array: &PrimitiveArray<$T> = array.as_any().downcast_ref().unwrap();
58            let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask);
59            Box::new(PrimitiveArray::from_vec(values).with_validity(validity))
60        }),
61        Boolean => {
62            let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
63            let (values, validity) =
64                boolean::filter_bitmap_and_validity(array.values(), array.validity(), mask);
65            BooleanArray::new(array.dtype().clone(), values, validity).boxed()
66        },
67        BinaryView => {
68            let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
69            let views = array.views();
70            let validity = array.validity();
71            let (views, validity) = primitive::filter_values_and_validity(views, validity, mask);
72            unsafe {
73                BinaryViewArray::new_unchecked_unknown_md(
74                    array.dtype().clone(),
75                    views.into(),
76                    array.data_buffers().clone(),
77                    validity,
78                    Some(array.total_buffer_len()),
79                )
80            }
81            .boxed()
82        },
83        Utf8View => {
84            let array = array.as_any().downcast_ref::<Utf8ViewArray>().unwrap();
85            let views = array.views();
86            let validity = array.validity();
87            let (views, validity) = primitive::filter_values_and_validity(views, validity, mask);
88            unsafe {
89                BinaryViewArray::new_unchecked_unknown_md(
90                    arrow::datatypes::ArrowDataType::BinaryView,
91                    views.into(),
92                    array.data_buffers().clone(),
93                    validity,
94                    Some(array.total_buffer_len()),
95                )
96                .to_utf8view_unchecked()
97            }
98            .boxed()
99        },
100        _ => {
101            let iter = SlicesIterator::new(mask);
102            let mut mutable = make_builder(array.dtype());
103            mutable.reserve(iter.slots());
104            iter.for_each(|(start, len)| {
105                mutable.subslice_extend(array, start, len, ShareStrategy::Always)
106            });
107            mutable.freeze()
108        },
109    }
110}