arrow2/compute/
filter.rs

1//! Contains operators to filter arrays such as [`filter`].
2use crate::array::growable::{make_growable, Growable};
3use crate::bitmap::utils::{BitChunkIterExact, BitChunksExact};
4use crate::bitmap::{utils::SlicesIterator, Bitmap, MutableBitmap};
5use crate::chunk::Chunk;
6use crate::datatypes::DataType;
7use crate::error::Result;
8use crate::types::simd::Simd;
9use crate::types::BitChunkOnes;
10use crate::{array::*, types::NativeType};
11
12/// Function that can filter arbitrary arrays
13pub type Filter<'a> = Box<dyn Fn(&dyn Array) -> Box<dyn Array> + 'a + Send + Sync>;
14
/// Returns the length of the run of consecutive set bits at the "start" of `chunk`.
///
/// On little-endian targets the run is counted from the least-significant bit
/// (`trailing_ones`); on every other target it is counted from the
/// most-significant bit (`leading_ones`).
#[inline]
fn get_leading_ones(chunk: u64) -> u32 {
    let is_little_endian = cfg!(target_endian = "little");
    if !is_little_endian {
        return chunk.leading_ones();
    }
    chunk.trailing_ones()
}
23
24/// # Safety
25/// This assumes that the `mask_chunks` contains a number of set/true items equal
26/// to `filter_count`
27unsafe fn nonnull_filter_impl<T, I>(values: &[T], mut mask_chunks: I, filter_count: usize) -> Vec<T>
28where
29    T: NativeType + Simd,
30    I: BitChunkIterExact<u64>,
31{
32    let mut chunks = values.chunks_exact(64);
33    let mut new = Vec::<T>::with_capacity(filter_count);
34    let mut dst = new.as_mut_ptr();
35
36    chunks
37        .by_ref()
38        .zip(mask_chunks.by_ref())
39        .for_each(|(chunk, mask_chunk)| {
40            let ones = mask_chunk.count_ones();
41            let leading_ones = get_leading_ones(mask_chunk);
42
43            if ones == leading_ones {
44                let size = leading_ones as usize;
45                unsafe {
46                    std::ptr::copy(chunk.as_ptr(), dst, size);
47                    dst = dst.add(size);
48                }
49                return;
50            }
51
52            let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize);
53            for pos in ones_iter {
54                dst.write(*chunk.get_unchecked(pos));
55                dst = dst.add(1);
56            }
57        });
58
59    chunks
60        .remainder()
61        .iter()
62        .zip(mask_chunks.remainder_iter())
63        .for_each(|(value, b)| {
64            if b {
65                unsafe {
66                    dst.write(*value);
67                    dst = dst.add(1);
68                };
69            }
70        });
71
72    unsafe { new.set_len(filter_count) };
73    new
74}
75
76/// # Safety
77/// This assumes that the `mask_chunks` contains a number of set/true items equal
78/// to `filter_count`
79unsafe fn null_filter_impl<T, I>(
80    values: &[T],
81    validity: &Bitmap,
82    mut mask_chunks: I,
83    filter_count: usize,
84) -> (Vec<T>, MutableBitmap)
85where
86    T: NativeType + Simd,
87    I: BitChunkIterExact<u64>,
88{
89    let mut chunks = values.chunks_exact(64);
90
91    let mut validity_chunks = validity.chunks::<u64>();
92
93    let mut new = Vec::<T>::with_capacity(filter_count);
94    let mut dst = new.as_mut_ptr();
95    let mut new_validity = MutableBitmap::with_capacity(filter_count);
96
97    chunks
98        .by_ref()
99        .zip(validity_chunks.by_ref())
100        .zip(mask_chunks.by_ref())
101        .for_each(|((chunk, validity_chunk), mask_chunk)| {
102            let ones = mask_chunk.count_ones();
103            let leading_ones = get_leading_ones(mask_chunk);
104
105            if ones == leading_ones {
106                let size = leading_ones as usize;
107                unsafe {
108                    std::ptr::copy(chunk.as_ptr(), dst, size);
109                    dst = dst.add(size);
110
111                    // safety: invariant offset + length <= slice.len()
112                    new_validity.extend_from_slice_unchecked(
113                        validity_chunk.to_ne_bytes().as_ref(),
114                        0,
115                        size,
116                    );
117                }
118                return;
119            }
120
121            // this triggers a bitcount
122            let ones_iter = BitChunkOnes::from_known_count(mask_chunk, ones as usize);
123            for pos in ones_iter {
124                dst.write(*chunk.get_unchecked(pos));
125                dst = dst.add(1);
126                new_validity.push_unchecked(validity_chunk & (1 << pos) > 0);
127            }
128        });
129
130    chunks
131        .remainder()
132        .iter()
133        .zip(validity_chunks.remainder_iter())
134        .zip(mask_chunks.remainder_iter())
135        .for_each(|((value, is_valid), is_selected)| {
136            if is_selected {
137                unsafe {
138                    dst.write(*value);
139                    dst = dst.add(1);
140                    new_validity.push_unchecked(is_valid);
141                };
142            }
143        });
144
145    unsafe { new.set_len(filter_count) };
146    (new, new_validity)
147}
148
149fn null_filter_simd<T: NativeType + Simd>(
150    values: &[T],
151    validity: &Bitmap,
152    mask: &Bitmap,
153) -> (Vec<T>, MutableBitmap) {
154    assert_eq!(values.len(), mask.len());
155    let filter_count = mask.len() - mask.unset_bits();
156
157    let (slice, offset, length) = mask.as_slice();
158    if offset == 0 {
159        let mask_chunks = BitChunksExact::<u64>::new(slice, length);
160        unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) }
161    } else {
162        let mask_chunks = mask.chunks::<u64>();
163        unsafe { null_filter_impl(values, validity, mask_chunks, filter_count) }
164    }
165}
166
167fn nonnull_filter_simd<T: NativeType + Simd>(values: &[T], mask: &Bitmap) -> Vec<T> {
168    assert_eq!(values.len(), mask.len());
169    let filter_count = mask.len() - mask.unset_bits();
170
171    let (slice, offset, length) = mask.as_slice();
172    if offset == 0 {
173        let mask_chunks = BitChunksExact::<u64>::new(slice, length);
174        unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) }
175    } else {
176        let mask_chunks = mask.chunks::<u64>();
177        unsafe { nonnull_filter_impl(values, mask_chunks, filter_count) }
178    }
179}
180
181fn filter_nonnull_primitive<T: NativeType + Simd>(
182    array: &PrimitiveArray<T>,
183    mask: &Bitmap,
184) -> PrimitiveArray<T> {
185    assert_eq!(array.len(), mask.len());
186
187    if let Some(validity) = array.validity() {
188        let (values, validity) = null_filter_simd(array.values(), validity, mask);
189        PrimitiveArray::<T>::new(array.data_type().clone(), values.into(), validity.into())
190    } else {
191        let values = nonnull_filter_simd(array.values(), mask);
192        PrimitiveArray::<T>::new(array.data_type().clone(), values.into(), None)
193    }
194}
195
196fn filter_primitive<T: NativeType + Simd>(
197    array: &PrimitiveArray<T>,
198    mask: &BooleanArray,
199) -> PrimitiveArray<T> {
200    // todo: branch on mask.validity()
201    filter_nonnull_primitive(array, mask.values())
202}
203
204fn filter_growable<'a>(growable: &mut impl Growable<'a>, chunks: &[(usize, usize)]) {
205    chunks
206        .iter()
207        .for_each(|(start, len)| growable.extend(0, *start, *len));
208}
209
210/// Returns a prepared function optimized to filter multiple arrays.
211/// Creating this function requires time, but using it is faster than [filter] when the
212/// same filter needs to be applied to multiple arrays (e.g. a multiple columns).
213pub fn build_filter(filter: &BooleanArray) -> Result<Filter> {
214    let iter = SlicesIterator::new(filter.values());
215    let filter_count = iter.slots();
216    let chunks = iter.collect::<Vec<_>>();
217
218    use crate::datatypes::PhysicalType::*;
219    Ok(Box::new(move |array: &dyn Array| {
220        match array.data_type().to_physical_type() {
221            Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
222                let array = array.as_any().downcast_ref().unwrap();
223                let mut growable =
224                    growable::GrowablePrimitive::<$T>::new(vec![array], false, filter_count);
225                filter_growable(&mut growable, &chunks);
226                let array: PrimitiveArray<$T> = growable.into();
227                Box::new(array)
228            }),
229            Utf8 => {
230                let array = array.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
231                let mut growable = growable::GrowableUtf8::new(vec![array], false, filter_count);
232                filter_growable(&mut growable, &chunks);
233                let array: Utf8Array<i32> = growable.into();
234                Box::new(array)
235            }
236            LargeUtf8 => {
237                let array = array.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
238                let mut growable = growable::GrowableUtf8::new(vec![array], false, filter_count);
239                filter_growable(&mut growable, &chunks);
240                let array: Utf8Array<i64> = growable.into();
241                Box::new(array)
242            }
243            _ => {
244                let mut mutable = make_growable(&[array], false, filter_count);
245                chunks
246                    .iter()
247                    .for_each(|(start, len)| mutable.extend(0, *start, *len));
248                mutable.as_box()
249            }
250        }
251    }))
252}
253
254/// Filters an [Array], returning elements matching the filter (i.e. where the values are true).
255///
256/// Note that the nulls of `filter` are interpreted as `false` will lead to these elements being
257/// masked out.
258///
259/// # Example
260/// ```rust
261/// # use arrow2::array::{Int32Array, PrimitiveArray, BooleanArray};
262/// # use arrow2::error::Result;
263/// # use arrow2::compute::filter::filter;
264/// # fn main() -> Result<()> {
265/// let array = PrimitiveArray::from_slice([5, 6, 7, 8, 9]);
266/// let filter_array = BooleanArray::from_slice(&vec![true, false, false, true, false]);
267/// let c = filter(&array, &filter_array)?;
268/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
269/// assert_eq!(c, &PrimitiveArray::from_slice(vec![5, 8]));
270/// # Ok(())
271/// # }
272/// ```
273pub fn filter(array: &dyn Array, filter: &BooleanArray) -> Result<Box<dyn Array>> {
274    // The validities may be masking out `true` bits, making the filter operation
275    // based on the values incorrect
276    if let Some(validities) = filter.validity() {
277        let values = filter.values();
278        let new_values = values & validities;
279        let filter = BooleanArray::new(DataType::Boolean, new_values, None);
280        return crate::compute::filter::filter(array, &filter);
281    }
282
283    let false_count = filter.values().unset_bits();
284    if false_count == filter.len() {
285        assert_eq!(array.len(), filter.len());
286        return Ok(new_empty_array(array.data_type().clone()));
287    }
288    if false_count == 0 {
289        assert_eq!(array.len(), filter.len());
290        return Ok(array.to_boxed());
291    }
292
293    use crate::datatypes::PhysicalType::*;
294    match array.data_type().to_physical_type() {
295        Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
296            let array = array.as_any().downcast_ref().unwrap();
297            Ok(Box::new(filter_primitive::<$T>(array, filter)))
298        }),
299        _ => {
300            let iter = SlicesIterator::new(filter.values());
301            let mut mutable = make_growable(&[array], false, iter.slots());
302            iter.for_each(|(start, len)| mutable.extend(0, start, len));
303            Ok(mutable.as_box())
304        }
305    }
306}
307
308/// Returns a new [Chunk] with arrays containing only values matching the filter.
309/// This is a convenience function: filter multiple columns is embarassingly parallel.
310pub fn filter_chunk<A: AsRef<dyn Array>>(
311    columns: &Chunk<A>,
312    filter_values: &BooleanArray,
313) -> Result<Chunk<Box<dyn Array>>> {
314    let arrays = columns.arrays();
315
316    let num_colums = arrays.len();
317
318    let filtered_arrays = match num_colums {
319        1 => {
320            vec![filter(columns.arrays()[0].as_ref(), filter_values)?]
321        }
322        _ => {
323            let filter = build_filter(filter_values)?;
324            arrays.iter().map(|a| filter(a.as_ref())).collect()
325        }
326    };
327    Chunk::try_new(filtered_arrays)
328}