Skip to main content

arrow_select/
filter.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Defines filter kernels
19
20use std::ops::AddAssign;
21use std::sync::Arc;
22
23use arrow_array::builder::BooleanBufferBuilder;
24use arrow_array::cast::AsArray;
25use arrow_array::types::{
26    ArrowDictionaryKeyType, ArrowPrimitiveType, ByteArrayType, ByteViewType, RunEndIndexType,
27};
28use arrow_array::*;
29use arrow_buffer::{
30    ArrowNativeType, BooleanBuffer, NullBuffer, OffsetBuffer, RunEndBuffer, ScalarBuffer, bit_util,
31};
32use arrow_buffer::{Buffer, MutableBuffer};
33use arrow_data::bit_iterator::{BitIndexIterator, BitSliceIterator};
34use arrow_data::transform::MutableArrayData;
35use arrow_schema::*;
36
/// Selectivity cut-off used to pick between the two lazy iteration strategies.
///
/// If the filter selects more than this fraction of rows, use
/// [`SlicesIterator`] to copy ranges of values. Otherwise iterate
/// over individual rows using [`IndexIterator`].
///
/// The comparison is strict: a selectivity exactly equal to the threshold
/// still uses [`IndexIterator`].
///
/// Threshold of 0.8 chosen based on <https://dl.acm.org/doi/abs/10.1145/3465998.3466009>
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
44
45/// An iterator of `(usize, usize)` each representing an interval
46/// `[start, end)` whose slots of a bitmap [Buffer] are true.
47///
48/// Each interval corresponds to a contiguous region of memory to be
49/// "taken" from an array to be filtered.
50///
51/// ## Notes:
52///
53/// 1. Ignores the validity bitmap (ignores nulls)
54///
55/// 2. Only performant for filters that copy across long contiguous runs
56#[derive(Debug)]
57pub struct SlicesIterator<'a>(BitSliceIterator<'a>);
58
59impl<'a> SlicesIterator<'a> {
60    /// Creates a new iterator from a [BooleanArray]
61    pub fn new(filter: &'a BooleanArray) -> Self {
62        filter.values().into()
63    }
64}
65
66impl<'a> From<&'a BooleanBuffer> for SlicesIterator<'a> {
67    fn from(filter: &'a BooleanBuffer) -> Self {
68        Self(filter.set_slices())
69    }
70}
71
72impl Iterator for SlicesIterator<'_> {
73    type Item = (usize, usize);
74
75    fn next(&mut self) -> Option<Self::Item> {
76        self.0.next()
77    }
78}
79
80/// An iterator of `usize` whose index in [`BooleanArray`] is true
81///
82/// This provides the best performance on most predicates, apart from those which keep
83/// large runs and therefore favour [`SlicesIterator`]
84struct IndexIterator<'a> {
85    remaining: usize,
86    iter: BitIndexIterator<'a>,
87}
88
89impl<'a> IndexIterator<'a> {
90    fn new(filter: &'a BooleanArray, remaining: usize) -> Self {
91        assert_eq!(filter.null_count(), 0);
92        let iter = filter.values().set_indices();
93        Self { remaining, iter }
94    }
95
96    /// Collect this iterator as a [`Vec`]
97    /// This is more efficient than the standard `collect` as we can
98    /// pre-allocate the entire uninitialized buffer and then fill it (roughly 1.6x faster)
99    pub fn collect(mut self) -> Vec<usize> {
100        let len = self.remaining;
101        let mut result = Vec::with_capacity(len);
102        let ptr: *mut usize = result.as_mut_ptr();
103        for i in 0..len {
104            // SAFETY: we have allocated enough space in `result` and remaining
105            // correctly tracks the number of elements
106            let next = self.iter.next();
107            debug_assert!(next.is_some(), "IndexIterator exhausted early");
108            unsafe {
109                *ptr.add(i) = next.unwrap_unchecked();
110            }
111        }
112        // SAFETY: we have initialized `len` elements
113        unsafe {
114            result.set_len(len);
115        }
116        result
117    }
118}
119
impl Iterator for IndexIterator<'_> {
    type Item = usize;

    /// Yields the next selected index, or `None` once `remaining` indices
    /// have been produced.
    ///
    /// # Panics
    /// Panics if the underlying bit iterator is exhausted while `remaining`
    /// is still non-zero.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining != 0 {
            // Fascinatingly swapping these two lines around results in a 50%
            // performance regression for some benchmarks
            let next = self.iter.next().expect("IndexIterator exhausted early");
            self.remaining -= 1;
            // Must panic if exhausted early as trusted length iterator
            return Some(next);
        }
        None
    }

    /// Reports an exact size: both bounds equal the number of indices that
    /// are still expected.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}
139
140/// Counts the number of set bits in `filter`
141fn filter_count(filter: &BooleanArray) -> usize {
142    filter.values().count_set_bits()
143}
144
145/// Convert all null values in `BooleanArray` to `false`
146///
147/// This is useful for filter-like operations which select only `true`
148/// values, but not `false` or `NULL` values
149///
150/// Internally this is implemented as a bitwise `AND` operation with null bits
151/// and the boolean bits.
152///
153/// # Example
154/// ```
155/// # use arrow_array::{Array, BooleanArray};
156/// # use arrow_select::filter::prep_null_mask_filter;
157/// let filter = BooleanArray::from(vec![
158///   Some(true),
159///   Some(false),
160///   None
161/// ]);
162/// // convert Boolean array to a filter mask
163/// let null_mask = prep_null_mask_filter(&filter);
164/// // there are no nulls in the output mask
165/// assert!(null_mask.nulls().is_none());
166/// assert_eq!(null_mask, BooleanArray::from(vec![
167///  true,
168///  false,
169///  false, // Null is converted to false
170/// ]));
171/// ```
172pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
173    let nulls = filter.nulls().unwrap();
174    let mask = filter.values() & nulls.inner();
175    BooleanArray::new(mask, None)
176}
177
178/// Returns a filtered `values` [`Array`] where the corresponding elements of
179/// `predicate` are `true`.
180///
181/// If multiple arrays (or record batches) need to be filtered using the same predicate array,
182/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
183/// calling [FilterPredicate::filter_record_batch].
184///
185/// In contrast to this function, it is then the responsibility of the caller
186/// to use [FilterBuilder::optimize] if appropriate.
187///
188/// # See also
189/// * [`FilterBuilder`] for more control over the filtering process.
190/// * [`filter_record_batch`] to filter a [`RecordBatch`]
191/// * [`BatchCoalescer`]: to filter multiple [`RecordBatch`] and coalesce
192///   the results into a single array.
193///
194/// [`BatchCoalescer`]: crate::coalesce::BatchCoalescer
195///
196/// # Example
197/// ```rust
198/// # use arrow_array::{Int32Array, BooleanArray};
199/// # use arrow_select::filter::filter;
200/// let array = Int32Array::from(vec![5, 6, 7, 8, 9]);
201/// let filter_array = BooleanArray::from(vec![true, false, false, true, false]);
202/// let c = filter(&array, &filter_array).unwrap();
203/// let c = c.as_any().downcast_ref::<Int32Array>().unwrap();
204/// assert_eq!(c, &Int32Array::from(vec![5, 8]));
205/// ```
206pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef, ArrowError> {
207    let mut filter_builder = FilterBuilder::new(predicate);
208
209    if FilterBuilder::is_optimize_beneficial(values.data_type()) {
210        // Only optimize if filtering more than one array
211        // Otherwise, the overhead of optimization can be more than the benefit
212        filter_builder = filter_builder.optimize();
213    }
214
215    let predicate = filter_builder.build();
216
217    filter_array(values, &predicate)
218}
219
220/// Returns a filtered [RecordBatch] where the corresponding elements of
221/// `predicate` are true.
222///
223/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
224///
225/// If multiple record batches (or arrays) need to be filtered using the same predicate array,
226/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
227/// calling [FilterPredicate::filter_record_batch].
228/// In contrast to this function, it is then the responsibility of the caller
229/// to use [FilterBuilder::optimize] if appropriate.
230pub fn filter_record_batch(
231    record_batch: &RecordBatch,
232    predicate: &BooleanArray,
233) -> Result<RecordBatch, ArrowError> {
234    let mut filter_builder = FilterBuilder::new(predicate);
235    let num_cols = record_batch.num_columns();
236    if num_cols > 1
237        || (num_cols > 0
238            && FilterBuilder::is_optimize_beneficial(
239                record_batch.schema_ref().field(0).data_type(),
240            ))
241    {
242        // Only optimize if filtering more than one column or if the column contains multiple internal arrays
243        // Otherwise, the overhead of optimization can be more than the benefit
244        filter_builder = filter_builder.optimize();
245    }
246    let filter = filter_builder.build();
247
248    filter.filter_record_batch(record_batch)
249}
250
251/// A builder to construct [`FilterPredicate`]
252#[derive(Debug)]
253pub struct FilterBuilder {
254    filter: BooleanArray,
255    count: usize,
256    strategy: IterationStrategy,
257}
258
259impl FilterBuilder {
260    /// Create a new [`FilterBuilder`] that can be used to construct a [`FilterPredicate`]
261    pub fn new(filter: &BooleanArray) -> Self {
262        let filter = match filter.null_count() {
263            0 => filter.clone(),
264            _ => prep_null_mask_filter(filter),
265        };
266
267        let count = filter_count(&filter);
268        let strategy = IterationStrategy::default_strategy(filter.len(), count);
269
270        Self {
271            filter,
272            count,
273            strategy,
274        }
275    }
276
277    /// Compute an optimized representation of the provided `filter` mask that can be
278    /// applied to an array more quickly.
279    ///
280    /// When filtering multiple arrays (e.g. a [`RecordBatch`] or a
281    /// [`StructArray`] with multiple fields), optimizing the filter can provide
282    /// significant performance benefits.
283    ///
284    /// However, optimization takes time and can have a larger memory footprint
285    /// than the original mask, so it is often faster to filter a single array,
286    /// without filter optimization.
287    pub fn optimize(mut self) -> Self {
288        match self.strategy {
289            IterationStrategy::SlicesIterator => {
290                let slices = SlicesIterator::new(&self.filter).collect();
291                self.strategy = IterationStrategy::Slices(slices)
292            }
293            IterationStrategy::IndexIterator => {
294                let indices = IndexIterator::new(&self.filter, self.count).collect();
295                self.strategy = IterationStrategy::Indices(indices)
296            }
297            _ => {}
298        }
299        self
300    }
301
302    /// Determines if calling [FilterBuilder::optimize] is beneficial for the
303    /// given type even when filtering just a single array.
304    ///
305    /// See [`FilterBuilder::optimize`] for more details.
306    pub fn is_optimize_beneficial(data_type: &DataType) -> bool {
307        match data_type {
308            DataType::Struct(fields) => {
309                fields.len() > 1
310                    || fields.len() == 1
311                        && FilterBuilder::is_optimize_beneficial(fields[0].data_type())
312            }
313            DataType::Union(fields, UnionMode::Sparse) => !fields.is_empty(),
314            _ => false,
315        }
316    }
317
318    /// Construct the final `FilterPredicate`
319    pub fn build(self) -> FilterPredicate {
320        FilterPredicate {
321            filter: self.filter,
322            count: self.count,
323            strategy: self.strategy,
324        }
325    }
326}
327
328/// The iteration strategy used to evaluate [`FilterPredicate`]
329#[derive(Debug)]
330enum IterationStrategy {
331    /// A lazily evaluated iterator of ranges
332    SlicesIterator,
333    /// A lazily evaluated iterator of indices
334    IndexIterator,
335    /// A precomputed list of indices
336    Indices(Vec<usize>),
337    /// A precomputed array of ranges
338    Slices(Vec<(usize, usize)>),
339    /// Select all rows
340    All,
341    /// Select no rows
342    None,
343}
344
345impl IterationStrategy {
346    /// The default [`IterationStrategy`] for a filter of length `filter_length`
347    /// and selecting `filter_count` rows
348    fn default_strategy(filter_length: usize, filter_count: usize) -> Self {
349        if filter_length == 0 || filter_count == 0 {
350            return IterationStrategy::None;
351        }
352
353        if filter_count == filter_length {
354            return IterationStrategy::All;
355        }
356
357        // Compute the selectivity of the predicate by dividing the number of true
358        // bits in the predicate by the predicate's total length
359        //
360        // This can then be used as a heuristic for the optimal iteration strategy
361        let selectivity_frac = filter_count as f64 / filter_length as f64;
362        if selectivity_frac > FILTER_SLICES_SELECTIVITY_THRESHOLD {
363            return IterationStrategy::SlicesIterator;
364        }
365        IterationStrategy::IndexIterator
366    }
367}
368
369/// A filtering predicate that can be applied to an [`Array`]
370#[derive(Debug)]
371pub struct FilterPredicate {
372    filter: BooleanArray,
373    count: usize,
374    strategy: IterationStrategy,
375}
376
377impl FilterPredicate {
378    /// Selects rows from `values` based on this [`FilterPredicate`]
379    pub fn filter(&self, values: &dyn Array) -> Result<ArrayRef, ArrowError> {
380        filter_array(values, self)
381    }
382
383    /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
384    /// [`FilterPredicate`].
385    ///
386    /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
387    pub fn filter_record_batch(
388        &self,
389        record_batch: &RecordBatch,
390    ) -> Result<RecordBatch, ArrowError> {
391        let filtered_arrays = record_batch
392            .columns()
393            .iter()
394            .map(|a| filter_array(a, self))
395            .collect::<Result<Vec<_>, _>>()?;
396
397        // SAFETY: we know that the set of filtered arrays will match the schema of the original
398        // record batch
399        unsafe {
400            Ok(RecordBatch::new_unchecked(
401                record_batch.schema(),
402                filtered_arrays,
403                self.count,
404            ))
405        }
406    }
407
408    /// Number of rows being selected based on this [`FilterPredicate`]
409    pub fn count(&self) -> usize {
410        self.count
411    }
412
413    /// Filters the given `nulls` buffer using this predicate.
414    ///
415    /// Returns `None` when there is nothing to track in the output, either
416    /// because the input `nulls` was `None`, the input had no nulls, or the
417    /// filtered result has no nulls. Otherwise returns the filtered
418    /// [`NullBuffer`] with its precomputed null count.
419    pub fn filter_nulls(&self, nulls: Option<&NullBuffer>) -> Option<NullBuffer> {
420        let (null_count, nulls) = filter_null_mask(nulls, self)?;
421        let buffer = BooleanBuffer::new(nulls, 0, self.count);
422
423        debug_assert_eq!(null_count, buffer.len() - buffer.count_set_bits());
424        // SAFETY: `filter_null_mask` derived `null_count` from `buffer`, so it
425        // matches the number of unset bits as required by `new_unchecked`.
426        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
427    }
428}
429
/// Applies `predicate` to `values`, dispatching first on the predicate's
/// iteration strategy and then on the data type of `values`.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] if the predicate is longer
/// than `values`. Errors from the run-end encoded, struct and sparse union
/// kernels are propagated.
///
/// # Panics
/// Hits `unimplemented!` for run-end encoded or dictionary arrays that the
/// respective downcast macro does not match.
fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result<ArrayRef, ArrowError> {
    // A predicate shorter than `values` is accepted; a longer one is rejected
    if predicate.filter.len() > values.len() {
        return Err(ArrowError::InvalidArgumentError(format!(
            "Filter predicate of length {} is larger than target array of length {}",
            predicate.filter.len(),
            values.len()
        )));
    }

    match predicate.strategy {
        // Nothing selected: an empty array of the same data type
        IterationStrategy::None => Ok(new_empty_array(values.data_type())),
        // Every predicate row selected: the first `count` rows of `values`
        IterationStrategy::All => Ok(values.slice(0, predicate.count)),
        // actually filter
        _ => downcast_primitive_array! {
            values => Ok(Arc::new(filter_primitive(values, predicate))),
            DataType::Boolean => {
                let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
                Ok(Arc::new(filter_boolean(values, predicate)))
            }
            DataType::Utf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i32>(), predicate)))
            }
            DataType::LargeUtf8 => {
                Ok(Arc::new(filter_bytes(values.as_string::<i64>(), predicate)))
            }
            DataType::Utf8View => {
                Ok(Arc::new(filter_byte_view(values.as_string_view(), predicate)))
            }
            DataType::Binary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i32>(), predicate)))
            }
            DataType::LargeBinary => {
                Ok(Arc::new(filter_bytes(values.as_binary::<i64>(), predicate)))
            }
            DataType::BinaryView => {
                Ok(Arc::new(filter_byte_view(values.as_binary_view(), predicate)))
            }
            DataType::FixedSizeBinary(_) => {
                Ok(Arc::new(filter_fixed_size_binary(values.as_fixed_size_binary(), predicate)))
            }
            DataType::ListView(_) => {
                Ok(Arc::new(filter_list_view::<i32>(values.as_list_view(), predicate)))
            }
            DataType::LargeListView(_) => {
                Ok(Arc::new(filter_list_view::<i64>(values.as_list_view(), predicate)))
            }
            DataType::RunEndEncoded(_, _) => {
                downcast_run_array!{
                    values => Ok(Arc::new(filter_run_end_array(values, predicate)?)),
                    t => unimplemented!("Filter not supported for RunEndEncoded type {:?}", t)
                }
            }
            DataType::Dictionary(_, _) => downcast_dictionary_array! {
                values => Ok(Arc::new(filter_dict(values, predicate))),
                t => unimplemented!("Filter not supported for dictionary type {:?}", t)
            }
            DataType::Struct(_) => {
                Ok(Arc::new(filter_struct(values.as_struct(), predicate)?))
            }
            DataType::Union(_, UnionMode::Sparse) => {
                Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?))
            }
            // Every data type without a dedicated kernel above
            _ => {
                let data = values.to_data();
                // fallback to using MutableArrayData
                let mut mutable = MutableArrayData::new(
                    vec![&data],
                    false,
                    predicate.count,
                );

                // The fallback always copies ranges: reuse precomputed slices
                // when available, otherwise derive them from the filter mask
                match &predicate.strategy {
                    IterationStrategy::Slices(slices) => {
                        slices
                            .iter()
                            .for_each(|(start, end)| mutable.extend(0, *start, *end));
                    }
                    _ => {
                        let iter = SlicesIterator::new(&predicate.filter);
                        iter.for_each(|(start, end)| mutable.extend(0, start, end));
                    }
                }

                let data = mutable.freeze();
                Ok(make_array(data))
            }
        },
    }
}
519
/// Filter any supported [`RunArray`] based on a [`FilterPredicate`]
///
/// Walks the physical runs covered by `array` and counts how many selected
/// rows of the predicate fall into each run. A run is kept if at least one of
/// its rows is selected, and its new run end is the running total of selected
/// rows so far. The values of the array are then filtered with the resulting
/// per-run keep mask.
///
/// # Errors
/// Propagates errors from filtering the values and from constructing the
/// output run-ends array and [`RunArray`].
fn filter_run_end_array<R: RunEndIndexType>(
    array: &RunArray<R>,
    predicate: &FilterPredicate,
) -> Result<RunArray<R>, ArrowError>
where
    R::Native: Into<i64> + From<bool>,
    R::Native: AddAssign,
{
    let run_ends: &RunEndBuffer<R::Native> = array.run_ends();
    let start_physical = run_ends.get_start_physical_index();
    let end_physical = run_ends.get_end_physical_index();
    // Number of physical runs between the start and end index, inclusive
    let physical_len = end_physical - start_physical + 1;

    // One slot per physical run; truncated to the number of kept runs below
    let mut new_run_ends = vec![R::default_value(); physical_len];
    let offset = run_ends.offset() as u64;

    // Start of the current run, relative to `offset`
    let mut start = 0u64;
    // Number of runs kept so far, which is also the next slot to write
    let mut j = 0;
    // Running total of selected rows
    let mut count = R::default_value();
    let filter_values = predicate.filter.values();
    let run_ends = run_ends.inner();

    let pred: BooleanArray = BooleanBuffer::collect_bool(physical_len, |i| {
        let mut keep = false;
        // End of this run relative to `offset`, then clamped to the filter length
        let mut end = (run_ends[i + start_physical].into() as u64).saturating_sub(offset);
        let difference = end.saturating_sub(filter_values.len() as u64);
        end -= difference;

        // Safety: we subtract the difference off `end` so we are always within bounds
        for pred in (start..end).map(|i| unsafe { filter_values.value_unchecked(i as usize) }) {
            count += R::Native::from(pred);
            keep |= pred
        }
        // this is to avoid branching: the slot is always written, but `j` only
        // advances when the run is kept, so a dropped run gets overwritten
        new_run_ends[j] = count;
        j += keep as usize;

        start = end;
        keep
    })
    .into();

    // Drop the slots that belong to runs which were not kept
    new_run_ends.truncate(j);

    let values = array.values_slice();
    let values = filter(values.as_ref(), &pred)?;

    let run_ends = PrimitiveArray::<R>::try_new(new_run_ends.into(), None)?;
    RunArray::try_new(&run_ends, &values)
}
571
572/// Computes a new null mask for `data` based on `predicate`
573///
574/// If the predicate selected no null-rows, returns `None`, otherwise returns
575/// `Some((null_count, null_buffer))` where `null_count` is the number of nulls
576/// in the filtered output, and `null_buffer` is the filtered null buffer
577///
578fn filter_null_mask(
579    nulls: Option<&NullBuffer>,
580    predicate: &FilterPredicate,
581) -> Option<(usize, Buffer)> {
582    let nulls = nulls?;
583    if nulls.null_count() == 0 {
584        return None;
585    }
586
587    let nulls = filter_bits(nulls.inner(), predicate);
588    // The filtered `nulls` has a length of `predicate.count` bits and
589    // therefore the null count is this minus the number of valid bits
590    let null_count = predicate.count - nulls.count_set_bits_offset(0, predicate.count);
591
592    if null_count == 0 {
593        return None;
594    }
595
596    Some((null_count, nulls))
597}
598
599/// Filter the packed bitmask `buffer`, with `predicate` starting at bit offset `offset`
600fn filter_bits(buffer: &BooleanBuffer, predicate: &FilterPredicate) -> Buffer {
601    let src = buffer.values();
602    let offset = buffer.offset();
603    assert!(buffer.len() >= predicate.filter.len());
604
605    match &predicate.strategy {
606        IterationStrategy::IndexIterator => {
607            let bits =
608                // SAFETY: IndexIterator uses the filter predicate to derive indices
609                IndexIterator::new(&predicate.filter, predicate.count).map(|src_idx| unsafe {
610                    bit_util::get_bit_raw(buffer.values().as_ptr(), src_idx + offset)
611                });
612
613            // SAFETY: `IndexIterator` reports its size correctly
614            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
615        }
616        IterationStrategy::Indices(indices) => {
617            // SAFETY: indices were derived from the filter predicate
618            let bits = indices.iter().map(|src_idx| unsafe {
619                bit_util::get_bit_raw(buffer.values().as_ptr(), *src_idx + offset)
620            });
621            // SAFETY: `Vec::iter()` reports its size correctly
622            unsafe { MutableBuffer::from_trusted_len_iter_bool(bits).into() }
623        }
624        IterationStrategy::SlicesIterator => {
625            let mut builder = BooleanBufferBuilder::new(predicate.count);
626            for (start, end) in SlicesIterator::new(&predicate.filter) {
627                builder.append_packed_range(start + offset..end + offset, src)
628            }
629            builder.into()
630        }
631        IterationStrategy::Slices(slices) => {
632            let mut builder = BooleanBufferBuilder::new(predicate.count);
633            for (start, end) in slices {
634                builder.append_packed_range(*start + offset..*end + offset, src)
635            }
636            builder.into()
637        }
638        IterationStrategy::All | IterationStrategy::None => unreachable!(),
639    }
640}
641
642/// `filter` implementation for boolean buffers
643fn filter_boolean(array: &BooleanArray, predicate: &FilterPredicate) -> BooleanArray {
644    let buffer = filter_bits(array.values(), predicate);
645    let values = BooleanBuffer::new(buffer, 0, predicate.count);
646    let nulls = predicate.filter_nulls(array.nulls());
647
648    BooleanArray::new(values, nulls)
649}
650
651#[inline(never)]
652fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate) -> Buffer {
653    assert!(values.len() >= predicate.filter.len());
654
655    match &predicate.strategy {
656        IterationStrategy::SlicesIterator => {
657            let mut buffer = Vec::with_capacity(predicate.count);
658            for (start, end) in SlicesIterator::new(&predicate.filter) {
659                // SAFETY: indices were derived from the filter predicate
660                buffer.extend_from_slice(unsafe { values.get_unchecked(start..end) });
661            }
662            buffer.into()
663        }
664        IterationStrategy::Slices(slices) => {
665            let mut buffer = Vec::with_capacity(predicate.count);
666            for (start, end) in slices {
667                // SAFETY: indices were derived from the filter predicate
668                buffer.extend_from_slice(unsafe { values.get_unchecked(*start..*end) });
669            }
670            buffer.into()
671        }
672        IterationStrategy::IndexIterator => {
673            // SAFETY: indices were derived from the filter predicate
674            let iter = IndexIterator::new(&predicate.filter, predicate.count)
675                .map(|x| unsafe { *values.get_unchecked(x) });
676
677            // SAFETY: IndexIterator is trusted length
678            unsafe { MutableBuffer::from_trusted_len_iter(iter) }.into()
679        }
680        IterationStrategy::Indices(indices) => {
681            // SAFETY: indices were derived from the filter predicate
682            let iter = indices.iter().map(|x| unsafe { *values.get_unchecked(*x) });
683            iter.collect::<Vec<_>>().into()
684        }
685        IterationStrategy::All | IterationStrategy::None => unreachable!(),
686    }
687}
688
689/// `filter` implementation for primitive arrays
690fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
691where
692    T: ArrowPrimitiveType,
693{
694    let buffer = filter_native(array.values(), predicate);
695    let values = ScalarBuffer::new(buffer, 0, predicate.count);
696    let nulls = predicate.filter_nulls(array.nulls());
697    let filtered = PrimitiveArray::new(values, nulls);
698
699    // Avoid the compatibility check when the physical type already matches.
700    if array.data_type() == &T::DATA_TYPE {
701        filtered
702    } else {
703        filtered.with_data_type(array.data_type().clone())
704    }
705}
706
707/// [`FilterBytes`] is created from a source [`GenericByteArray`] and can be
708/// used to build a new [`GenericByteArray`] by copying values from the source
709///
710/// TODO(raphael): Could this be used for the take kernel as well?
711struct FilterBytes<'a, OffsetSize> {
712    src_offsets: &'a [OffsetSize],
713    src_values: &'a [u8],
714    dst_offsets: Vec<OffsetSize>,
715    dst_values: Vec<u8>,
716    cur_offset: OffsetSize,
717}
718
719impl<'a, OffsetSize> FilterBytes<'a, OffsetSize>
720where
721    OffsetSize: OffsetSizeTrait,
722{
723    fn new<T>(capacity: usize, array: &'a GenericByteArray<T>) -> Self
724    where
725        T: ByteArrayType<Offset = OffsetSize>,
726    {
727        let dst_values = Vec::new();
728        let mut dst_offsets: Vec<OffsetSize> = Vec::with_capacity(capacity + 1);
729        let cur_offset = OffsetSize::from_usize(0).unwrap();
730
731        dst_offsets.push(cur_offset);
732
733        Self {
734            src_offsets: array.value_offsets(),
735            src_values: array.value_data(),
736            dst_offsets,
737            dst_values,
738            cur_offset,
739        }
740    }
741
742    /// Returns the byte offset at `idx`
743    #[inline]
744    fn get_value_offset(&self, idx: usize) -> usize {
745        self.src_offsets[idx].as_usize()
746    }
747
748    /// Returns the start and end of the value at index `idx` along with its length
749    #[inline]
750    fn get_value_range(&self, idx: usize) -> (usize, usize, OffsetSize) {
751        // These can only fail if `array` contains invalid data
752        let start = self.get_value_offset(idx);
753        let end = self.get_value_offset(idx + 1);
754        let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
755        (start, end, len)
756    }
757
758    fn extend_offsets_idx(&mut self, iter: impl Iterator<Item = usize>) {
759        self.dst_offsets.extend(iter.map(|idx| {
760            let start = self.src_offsets[idx].as_usize();
761            let end = self.src_offsets[idx + 1].as_usize();
762            let len = OffsetSize::from_usize(end - start).expect("illegal offset range");
763            self.cur_offset += len;
764
765            self.cur_offset
766        }));
767    }
768
769    /// Extends the in-progress array by the indexes in the provided iterator
770    fn extend_idx(&mut self, iter: impl Iterator<Item = usize>) {
771        self.dst_values.reserve_exact(self.cur_offset.as_usize());
772
773        for idx in iter {
774            let start = self.src_offsets[idx].as_usize();
775            let end = self.src_offsets[idx + 1].as_usize();
776            self.dst_values
777                .extend_from_slice(&self.src_values[start..end]);
778        }
779    }
780
781    fn extend_offsets_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>, count: usize) {
782        self.dst_offsets.reserve_exact(count);
783        for (start, end) in iter {
784            // These can only fail if `array` contains invalid data
785            for idx in start..end {
786                let (_, _, len) = self.get_value_range(idx);
787                self.cur_offset += len;
788                self.dst_offsets.push(self.cur_offset);
789            }
790        }
791    }
792
793    /// Extends the in-progress array by the ranges in the provided iterator
794    fn extend_slices(&mut self, iter: impl Iterator<Item = (usize, usize)>) {
795        self.dst_values.reserve_exact(self.cur_offset.as_usize());
796
797        for (start, end) in iter {
798            let value_start = self.get_value_offset(start);
799            let value_end = self.get_value_offset(end);
800            self.dst_values
801                .extend_from_slice(&self.src_values[value_start..value_end]);
802        }
803    }
804}
805
806/// `filter` implementation for byte arrays
807///
808/// Note: NULLs with a non-zero slot length in `array` will have the corresponding
809/// data copied across. This allows handling the null mask separately from the data
810fn filter_bytes<T>(array: &GenericByteArray<T>, predicate: &FilterPredicate) -> GenericByteArray<T>
811where
812    T: ByteArrayType,
813{
814    let mut filter = FilterBytes::new(predicate.count, array);
815
816    match &predicate.strategy {
817        IterationStrategy::SlicesIterator => {
818            filter.extend_offsets_slices(SlicesIterator::new(&predicate.filter), predicate.count);
819            filter.extend_slices(SlicesIterator::new(&predicate.filter))
820        }
821        IterationStrategy::Slices(slices) => {
822            filter.extend_offsets_slices(slices.iter().cloned(), predicate.count);
823            filter.extend_slices(slices.iter().cloned())
824        }
825        IterationStrategy::IndexIterator => {
826            filter.extend_offsets_idx(IndexIterator::new(&predicate.filter, predicate.count));
827            filter.extend_idx(IndexIterator::new(&predicate.filter, predicate.count))
828        }
829        IterationStrategy::Indices(indices) => {
830            filter.extend_offsets_idx(indices.iter().cloned());
831            filter.extend_idx(indices.iter().cloned())
832        }
833        IterationStrategy::All | IterationStrategy::None => unreachable!(),
834    }
835
836    // SAFETY: `dst_offsets` starts at `[0]` and only grows by the running
837    // `cur_offset`, so it is monotonically non-decreasing.
838    let offsets = unsafe { OffsetBuffer::new_unchecked(filter.dst_offsets.into()) };
839    let nulls = predicate.filter_nulls(array.nulls());
840
841    // SAFETY: `offsets` index into `dst_values` by construction, and each slot
842    // is a byte-for-byte copy from `array`, so UTF-8 validity (if any) is preserved.
843    // Length invariant: `offsets.len() - 1 == predicate.count == nulls.len()`.
844    unsafe { GenericByteArray::new_unchecked(offsets, filter.dst_values.into(), nulls) }
845}
846
847/// `filter` implementation for byte view arrays.
848fn filter_byte_view<T: ByteViewType>(
849    array: &GenericByteViewArray<T>,
850    predicate: &FilterPredicate,
851) -> GenericByteViewArray<T> {
852    let new_view_buffer = filter_native(array.views(), predicate);
853    let views = ScalarBuffer::new(new_view_buffer, 0, predicate.count);
854    let buffers = array.data_buffers().to_vec();
855    let nulls = predicate.filter_nulls(array.nulls());
856
857    // SAFETY: each view is copied unchanged from `array.views()` and `buffers`
858    // is the same buffer list, so every view still points to an in-bounds
859    // (and, for strings, UTF-8 valid) range.
860    unsafe { GenericByteViewArray::new_unchecked(views, buffers, nulls) }
861}
862
863fn filter_fixed_size_binary(
864    array: &FixedSizeBinaryArray,
865    predicate: &FilterPredicate,
866) -> FixedSizeBinaryArray {
867    let values: &[u8] = array.values();
868    let value_length = array.value_length() as usize;
869    let calculate_offset_from_index = |index: usize| index * value_length;
870    let buffer = match &predicate.strategy {
871        IterationStrategy::SlicesIterator => {
872            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
873            for (start, end) in SlicesIterator::new(&predicate.filter) {
874                buffer.extend_from_slice(
875                    &values[calculate_offset_from_index(start)..calculate_offset_from_index(end)],
876                );
877            }
878            buffer
879        }
880        IterationStrategy::Slices(slices) => {
881            let mut buffer = MutableBuffer::with_capacity(predicate.count * value_length);
882            for (start, end) in slices {
883                buffer.extend_from_slice(
884                    &values[calculate_offset_from_index(*start)..calculate_offset_from_index(*end)],
885                );
886            }
887            buffer
888        }
889        IterationStrategy::IndexIterator => {
890            let iter = IndexIterator::new(&predicate.filter, predicate.count).map(|x| {
891                &values[calculate_offset_from_index(x)..calculate_offset_from_index(x + 1)]
892            });
893
894            let mut buffer = MutableBuffer::new(predicate.count * value_length);
895            iter.for_each(|item| buffer.extend_from_slice(item));
896            buffer
897        }
898        IterationStrategy::Indices(indices) => {
899            let iter = indices.iter().map(|x| {
900                &values[calculate_offset_from_index(*x)..calculate_offset_from_index(*x + 1)]
901            });
902
903            let mut buffer = MutableBuffer::new(predicate.count * value_length);
904            iter.for_each(|item| buffer.extend_from_slice(item));
905            buffer
906        }
907        IterationStrategy::All | IterationStrategy::None => unreachable!(),
908    };
909
910    let nulls = predicate.filter_nulls(array.nulls());
911
912    FixedSizeBinaryArray::new(array.value_length(), buffer.into(), nulls)
913}
914
915/// `filter` implementation for dictionaries
916fn filter_dict<T>(array: &DictionaryArray<T>, predicate: &FilterPredicate) -> DictionaryArray<T>
917where
918    T: ArrowDictionaryKeyType,
919    T::Native: num_traits::Num,
920{
921    let builder = filter_primitive::<T>(array.keys(), predicate)
922        .into_data()
923        .into_builder()
924        .data_type(array.data_type().clone())
925        .child_data(vec![array.values().to_data()]);
926
927    // SAFETY:
928    // Keys were valid before, filtered subset is therefore still valid
929    DictionaryArray::from(unsafe { builder.build_unchecked() })
930}
931
932/// `filter` implementation for structs
933fn filter_struct(
934    array: &StructArray,
935    predicate: &FilterPredicate,
936) -> Result<StructArray, ArrowError> {
937    let columns = array
938        .columns()
939        .iter()
940        .map(|column| filter_array(column, predicate))
941        .collect::<Result<_, _>>()?;
942
943    let nulls = if let Some((null_count, nulls)) = filter_null_mask(array.nulls(), predicate) {
944        let buffer = BooleanBuffer::new(nulls, 0, predicate.count);
945
946        Some(unsafe { NullBuffer::new_unchecked(buffer, null_count) })
947    } else {
948        None
949    };
950
951    Ok(unsafe {
952        StructArray::new_unchecked_with_length(
953            array.fields().clone(),
954            columns,
955            nulls,
956            predicate.count(),
957        )
958    })
959}
960
961/// `filter` implementation for sparse unions
962fn filter_sparse_union(
963    array: &UnionArray,
964    predicate: &FilterPredicate,
965) -> Result<UnionArray, ArrowError> {
966    let DataType::Union(fields, UnionMode::Sparse) = array.data_type() else {
967        unreachable!()
968    };
969
970    let type_ids = filter_primitive(
971        &Int8Array::try_new(array.type_ids().clone(), None)?,
972        predicate,
973    );
974
975    let children = fields
976        .iter()
977        .map(|(child_type_id, _)| filter_array(array.child(child_type_id), predicate))
978        .collect::<Result<_, _>>()?;
979
980    Ok(unsafe {
981        UnionArray::new_unchecked(fields.clone(), type_ids.into_parts().1, None, children)
982    })
983}
984
985/// `filter` implementation for list views
986fn filter_list_view<OffsetType: OffsetSizeTrait>(
987    array: &GenericListViewArray<OffsetType>,
988    predicate: &FilterPredicate,
989) -> GenericListViewArray<OffsetType> {
990    let filtered_offsets = filter_native::<OffsetType>(array.offsets(), predicate);
991    let filtered_sizes = filter_native::<OffsetType>(array.sizes(), predicate);
992
993    let field = match array.data_type() {
994        DataType::ListView(field) | DataType::LargeListView(field) => field.clone(),
995        _ => unreachable!(),
996    };
997    let offsets = ScalarBuffer::new(filtered_offsets, 0, predicate.count);
998    let sizes = ScalarBuffer::new(filtered_sizes, 0, predicate.count);
999    let values = array.values().clone();
1000    let nulls = predicate.filter_nulls(array.nulls());
1001
1002    // SAFETY: each `(offset, size)` pair is copied unchanged from `array` and
1003    // indexes into the same `values` child, so every range stays in-bounds.
1004    // `field` and `values`' data type are unchanged from `array`.
1005    unsafe { GenericListViewArray::new_unchecked(field, offsets, sizes, values, nulls) }
1006}
1007
1008#[cfg(test)]
1009mod tests {
1010    use super::*;
1011    use arrow_array::builder::*;
1012    use arrow_array::cast::as_run_array;
1013    use arrow_array::types::*;
1014    use rand::distr::uniform::{UniformSampler, UniformUsize};
1015    use rand::distr::{Alphanumeric, StandardUniform};
1016    use rand::prelude::*;
1017    use rand::rng;
1018
1019    macro_rules! def_temporal_test {
1020        ($test:ident, $array_type: ident, $data: expr) => {
1021            #[test]
1022            fn $test() {
1023                let a = $data;
1024                let b = BooleanArray::from(vec![true, false, true, false]);
1025                let c = filter(&a, &b).unwrap();
1026                let d = c.as_ref().as_any().downcast_ref::<$array_type>().unwrap();
1027                assert_eq!(2, d.len());
1028                assert_eq!(1, d.value(0));
1029                assert_eq!(3, d.value(1));
1030            }
1031        };
1032    }
1033
    // One generated test per temporal array type. Each invocation passes the
    // test name, the array type the result is downcast to, and the input array.

    // Date arrays
    def_temporal_test!(
        test_filter_date32,
        Date32Array,
        Date32Array::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_date64,
        Date64Array,
        Date64Array::from(vec![1, 2, 3, 4])
    );
    // Time-of-day arrays
    def_temporal_test!(
        test_filter_time32_second,
        Time32SecondArray,
        Time32SecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time32_millisecond,
        Time32MillisecondArray,
        Time32MillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_microsecond,
        Time64MicrosecondArray,
        Time64MicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_time64_nanosecond,
        Time64NanosecondArray,
        Time64NanosecondArray::from(vec![1, 2, 3, 4])
    );
    // Duration arrays
    def_temporal_test!(
        test_filter_duration_second,
        DurationSecondArray,
        DurationSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_millisecond,
        DurationMillisecondArray,
        DurationMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_microsecond,
        DurationMicrosecondArray,
        DurationMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_duration_nanosecond,
        DurationNanosecondArray,
        DurationNanosecondArray::from(vec![1, 2, 3, 4])
    );
    // Timestamp arrays
    def_temporal_test!(
        test_filter_timestamp_second,
        TimestampSecondArray,
        TimestampSecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_millisecond,
        TimestampMillisecondArray,
        TimestampMillisecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_microsecond,
        TimestampMicrosecondArray,
        TimestampMicrosecondArray::from(vec![1, 2, 3, 4])
    );
    def_temporal_test!(
        test_filter_timestamp_nanosecond,
        TimestampNanosecondArray,
        TimestampNanosecondArray::from(vec![1, 2, 3, 4])
    );
1104
1105    #[test]
1106    fn test_filter_array_slice() {
1107        let a = Int32Array::from(vec![5, 6, 7, 8, 9]).slice(1, 4);
1108        let b = BooleanArray::from(vec![true, false, false, true]);
1109        // filtering with sliced filter array is not currently supported
1110        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1111        // let b = b_slice.as_any().downcast_ref().unwrap();
1112        let c = filter(&a, &b).unwrap();
1113        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1114        assert_eq!(2, d.len());
1115        assert_eq!(6, d.value(0));
1116        assert_eq!(9, d.value(1));
1117    }
1118
1119    #[test]
1120    fn test_filter_array_low_density() {
1121        // this test exercises the all 0's branch of the filter algorithm
1122        let mut data_values = (1..=65).collect::<Vec<i32>>();
1123        let mut filter_values = (1..=65).map(|i| matches!(i % 65, 0)).collect::<Vec<bool>>();
1124        // set up two more values after the batch
1125        data_values.extend_from_slice(&[66, 67]);
1126        filter_values.extend_from_slice(&[false, true]);
1127        let a = Int32Array::from(data_values);
1128        let b = BooleanArray::from(filter_values);
1129        let c = filter(&a, &b).unwrap();
1130        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1131        assert_eq!(2, d.len());
1132        assert_eq!(65, d.value(0));
1133        assert_eq!(67, d.value(1));
1134    }
1135
1136    #[test]
1137    fn test_filter_array_high_density() {
1138        // this test exercises the all 1's branch of the filter algorithm
1139        let mut data_values = (1..=65).map(Some).collect::<Vec<_>>();
1140        let mut filter_values = (1..=65)
1141            .map(|i| !matches!(i % 65, 0))
1142            .collect::<Vec<bool>>();
1143        // set second data value to null
1144        data_values[1] = None;
1145        // set up two more values after the batch
1146        data_values.extend_from_slice(&[Some(66), None, Some(67), None]);
1147        filter_values.extend_from_slice(&[false, true, true, true]);
1148        let a = Int32Array::from(data_values);
1149        let b = BooleanArray::from(filter_values);
1150        let c = filter(&a, &b).unwrap();
1151        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1152        assert_eq!(67, d.len());
1153        assert_eq!(3, d.null_count());
1154        assert_eq!(1, d.value(0));
1155        assert!(d.is_null(1));
1156        assert_eq!(64, d.value(63));
1157        assert!(d.is_null(64));
1158        assert_eq!(67, d.value(65));
1159    }
1160
1161    #[test]
1162    fn test_filter_string_array_simple() {
1163        let a = StringArray::from(vec!["hello", " ", "world", "!"]);
1164        let b = BooleanArray::from(vec![true, false, true, false]);
1165        let c = filter(&a, &b).unwrap();
1166        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1167        assert_eq!(2, d.len());
1168        assert_eq!("hello", d.value(0));
1169        assert_eq!("world", d.value(1));
1170    }
1171
1172    #[test]
1173    fn test_filter_primitive_array_with_null() {
1174        let a = Int32Array::from(vec![Some(5), None]);
1175        let b = BooleanArray::from(vec![false, true]);
1176        let c = filter(&a, &b).unwrap();
1177        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1178        assert_eq!(1, d.len());
1179        assert!(d.is_null(0));
1180    }
1181
1182    #[test]
1183    fn test_filter_string_array_with_null() {
1184        let a = StringArray::from(vec![Some("hello"), None, Some("world"), None]);
1185        let b = BooleanArray::from(vec![true, false, false, true]);
1186        let c = filter(&a, &b).unwrap();
1187        let d = c.as_ref().as_any().downcast_ref::<StringArray>().unwrap();
1188        assert_eq!(2, d.len());
1189        assert_eq!("hello", d.value(0));
1190        assert!(!d.is_null(0));
1191        assert!(d.is_null(1));
1192    }
1193
1194    #[test]
1195    fn test_filter_binary_array_with_null() {
1196        let data: Vec<Option<&[u8]>> = vec![Some(b"hello"), None, Some(b"world"), None];
1197        let a = BinaryArray::from(data);
1198        let b = BooleanArray::from(vec![true, false, false, true]);
1199        let c = filter(&a, &b).unwrap();
1200        let d = c.as_ref().as_any().downcast_ref::<BinaryArray>().unwrap();
1201        assert_eq!(2, d.len());
1202        assert_eq!(b"hello", d.value(0));
1203        assert!(!d.is_null(0));
1204        assert!(d.is_null(1));
1205    }
1206
1207    fn _test_filter_byte_view<T>()
1208    where
1209        T: ByteViewType,
1210        str: AsRef<T::Native>,
1211        T::Native: PartialEq,
1212    {
1213        let array = {
1214            // ["hello", "world", null, "large payload over 12 bytes", "lulu"]
1215            let mut builder = GenericByteViewBuilder::<T>::new();
1216            builder.append_value("hello");
1217            builder.append_value("world");
1218            builder.append_null();
1219            builder.append_value("large payload over 12 bytes");
1220            builder.append_value("lulu");
1221            builder.finish()
1222        };
1223
1224        {
1225            let predicate = BooleanArray::from(vec![true, false, true, true, false]);
1226            let actual = filter(&array, &predicate).unwrap();
1227
1228            assert_eq!(actual.len(), 3);
1229
1230            let expected = {
1231                // ["hello", null, "large payload over 12 bytes"]
1232                let mut builder = GenericByteViewBuilder::<T>::new();
1233                builder.append_value("hello");
1234                builder.append_null();
1235                builder.append_value("large payload over 12 bytes");
1236                builder.finish()
1237            };
1238
1239            assert_eq!(actual.as_ref(), &expected);
1240        }
1241
1242        {
1243            let predicate = BooleanArray::from(vec![true, false, false, false, true]);
1244            let actual = filter(&array, &predicate).unwrap();
1245
1246            assert_eq!(actual.len(), 2);
1247
1248            let expected = {
1249                // ["hello", "lulu"]
1250                let mut builder = GenericByteViewBuilder::<T>::new();
1251                builder.append_value("hello");
1252                builder.append_value("lulu");
1253                builder.finish()
1254            };
1255
1256            assert_eq!(actual.as_ref(), &expected);
1257        }
1258    }
1259
1260    #[test]
1261    fn test_filter_string_view() {
1262        _test_filter_byte_view::<StringViewType>()
1263    }
1264
1265    #[test]
1266    fn test_filter_binary_view() {
1267        _test_filter_byte_view::<BinaryViewType>()
1268    }
1269
1270    #[test]
1271    fn test_filter_fixed_binary() {
1272        let v1 = [1_u8, 2];
1273        let v2 = [3_u8, 4];
1274        let v3 = [5_u8, 6];
1275        let v = vec![&v1, &v2, &v3];
1276        let a = FixedSizeBinaryArray::try_from(v).unwrap();
1277        let b = BooleanArray::from(vec![true, false, true]);
1278        let c = filter(&a, &b).unwrap();
1279        let d = c
1280            .as_ref()
1281            .as_any()
1282            .downcast_ref::<FixedSizeBinaryArray>()
1283            .unwrap();
1284        assert_eq!(d.len(), 2);
1285        assert_eq!(d.value(0), &v1);
1286        assert_eq!(d.value(1), &v3);
1287        let c2 = FilterBuilder::new(&b)
1288            .optimize()
1289            .build()
1290            .filter(&a)
1291            .unwrap();
1292        let d2 = c2
1293            .as_ref()
1294            .as_any()
1295            .downcast_ref::<FixedSizeBinaryArray>()
1296            .unwrap();
1297        assert_eq!(d, d2);
1298
1299        let b = BooleanArray::from(vec![false, false, false]);
1300        let c = filter(&a, &b).unwrap();
1301        let d = c
1302            .as_ref()
1303            .as_any()
1304            .downcast_ref::<FixedSizeBinaryArray>()
1305            .unwrap();
1306        assert_eq!(d.len(), 0);
1307
1308        let b = BooleanArray::from(vec![true, true, true]);
1309        let c = filter(&a, &b).unwrap();
1310        let d = c
1311            .as_ref()
1312            .as_any()
1313            .downcast_ref::<FixedSizeBinaryArray>()
1314            .unwrap();
1315        assert_eq!(d.len(), 3);
1316        assert_eq!(d.value(0), &v1);
1317        assert_eq!(d.value(1), &v2);
1318        assert_eq!(d.value(2), &v3);
1319
1320        let b = BooleanArray::from(vec![false, false, true]);
1321        let c = filter(&a, &b).unwrap();
1322        let d = c
1323            .as_ref()
1324            .as_any()
1325            .downcast_ref::<FixedSizeBinaryArray>()
1326            .unwrap();
1327        assert_eq!(d.len(), 1);
1328        assert_eq!(d.value(0), &v3);
1329        let c2 = FilterBuilder::new(&b)
1330            .optimize()
1331            .build()
1332            .filter(&a)
1333            .unwrap();
1334        let d2 = c2
1335            .as_ref()
1336            .as_any()
1337            .downcast_ref::<FixedSizeBinaryArray>()
1338            .unwrap();
1339        assert_eq!(d, d2);
1340    }
1341
1342    #[test]
1343    fn test_filter_array_slice_with_null() {
1344        let a = Int32Array::from(vec![Some(5), None, Some(7), Some(8), Some(9)]).slice(1, 4);
1345        let b = BooleanArray::from(vec![true, false, false, true]);
1346        // filtering with sliced filter array is not currently supported
1347        // let b_slice = BooleanArray::from(vec![true, false, false, true, false]).slice(1, 4);
1348        // let b = b_slice.as_any().downcast_ref().unwrap();
1349        let c = filter(&a, &b).unwrap();
1350        let d = c.as_ref().as_any().downcast_ref::<Int32Array>().unwrap();
1351        assert_eq!(2, d.len());
1352        assert!(d.is_null(0));
1353        assert!(!d.is_null(1));
1354        assert_eq!(9, d.value(1));
1355    }
1356
1357    #[test]
1358    fn test_filter_run_end_encoding_array() {
1359        let run_ends = Int64Array::from(vec![2, 3, 8]);
1360        let values = Int64Array::from(vec![7, -2, 9]);
1361        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1362        let b = BooleanArray::from(vec![true, false, true, false, true, false, true, false]);
1363        let c = filter(&a, &b).unwrap();
1364        let actual: &RunArray<Int64Type> = as_run_array(&c);
1365        assert_eq!(4, actual.len());
1366
1367        let expected = RunArray::try_new(
1368            &Int64Array::from(vec![1, 2, 4]),
1369            &Int64Array::from(vec![7, -2, 9]),
1370        )
1371        .expect("Failed to make expected RunArray test is broken");
1372
1373        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1374        assert_eq!(actual.values(), expected.values())
1375    }
1376
1377    #[test]
1378    fn test_filter_run_end_encoding_array_sliced() {
1379        let run_ends = Int64Array::from(vec![2, 3, 8]);
1380        let values = Int64Array::from(vec![7, -2, 9]);
1381        let a = RunArray::try_new(&run_ends, &values).unwrap(); // [7, 7, -2, 9, 9, 9, 9, 9]
1382        let a = a.slice(2, 3); // [-2, 9, 9]
1383        let b = BooleanArray::from(vec![true, false, true]);
1384        let result = filter(&a, &b).unwrap();
1385
1386        let result = result.as_run::<Int64Type>();
1387        let result = result.downcast::<Int64Array>().unwrap();
1388
1389        let expected = vec![-2, 9];
1390        let actual = result.into_iter().flatten().collect::<Vec<_>>();
1391        assert_eq!(expected, actual);
1392    }
1393
1394    #[test]
1395    fn test_filter_run_end_encoding_array_remove_value() {
1396        let run_ends = Int32Array::from(vec![2, 3, 8, 10]);
1397        let values = Int32Array::from(vec![7, -2, 9, -8]);
1398        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1399        let b = BooleanArray::from(vec![
1400            false, true, false, false, true, false, true, false, false, false,
1401        ]);
1402        let c = filter(&a, &b).unwrap();
1403        let actual: &RunArray<Int32Type> = as_run_array(&c);
1404        assert_eq!(3, actual.len());
1405
1406        let expected =
1407            RunArray::try_new(&Int32Array::from(vec![1, 3]), &Int32Array::from(vec![7, 9]))
1408                .expect("Failed to make expected RunArray test is broken");
1409
1410        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1411        assert_eq!(actual.values(), expected.values())
1412    }
1413
1414    #[test]
1415    fn test_filter_run_end_encoding_array_remove_all_but_one() {
1416        let run_ends = Int16Array::from(vec![2, 3, 8, 10]);
1417        let values = Int16Array::from(vec![7, -2, 9, -8]);
1418        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1419        let b = BooleanArray::from(vec![
1420            false, false, false, false, false, false, true, false, false, false,
1421        ]);
1422        let c = filter(&a, &b).unwrap();
1423        let actual: &RunArray<Int16Type> = as_run_array(&c);
1424        assert_eq!(1, actual.len());
1425
1426        let expected = RunArray::try_new(&Int16Array::from(vec![1]), &Int16Array::from(vec![9]))
1427            .expect("Failed to make expected RunArray test is broken");
1428
1429        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1430        assert_eq!(actual.values(), expected.values())
1431    }
1432
1433    #[test]
1434    fn test_filter_run_end_encoding_array_empty() {
1435        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1436        let values = Int64Array::from(vec![7, -2, 9, -8]);
1437        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1438        let b = BooleanArray::from(vec![
1439            false, false, false, false, false, false, false, false, false, false,
1440        ]);
1441        let c = filter(&a, &b).unwrap();
1442        let actual: &RunArray<Int64Type> = as_run_array(&c);
1443        assert_eq!(0, actual.len());
1444    }
1445
1446    #[test]
1447    fn test_filter_run_end_encoding_array_max_value_gt_predicate_len() {
1448        let run_ends = Int64Array::from(vec![2, 3, 8, 10]);
1449        let values = Int64Array::from(vec![7, -2, 9, -8]);
1450        let a = RunArray::try_new(&run_ends, &values).expect("Failed to create RunArray");
1451        let b = BooleanArray::from(vec![false, true, true]);
1452        let c = filter(&a, &b).unwrap();
1453        let actual: &RunArray<Int64Type> = as_run_array(&c);
1454        assert_eq!(2, actual.len());
1455
1456        let expected = RunArray::try_new(
1457            &Int64Array::from(vec![1, 2]),
1458            &Int64Array::from(vec![7, -2]),
1459        )
1460        .expect("Failed to make expected RunArray test is broken");
1461
1462        assert_eq!(&actual.run_ends().values(), &expected.run_ends().values());
1463        assert_eq!(actual.values(), expected.values())
1464    }
1465
1466    #[test]
1467    fn test_filter_dictionary_array() {
1468        let values = [Some("hello"), None, Some("world"), Some("!")];
1469        let a: Int8DictionaryArray = values.iter().copied().collect();
1470        let b = BooleanArray::from(vec![false, true, true, false]);
1471        let c = filter(&a, &b).unwrap();
1472        let d = c
1473            .as_ref()
1474            .as_any()
1475            .downcast_ref::<Int8DictionaryArray>()
1476            .unwrap();
1477        let value_array = d.values();
1478        let values = value_array.as_any().downcast_ref::<StringArray>().unwrap();
1479        // values are cloned in the filtered dictionary array
1480        assert_eq!(3, values.len());
1481        // but keys are filtered
1482        assert_eq!(2, d.len());
1483        assert!(d.is_null(0));
1484        assert_eq!("world", values.value(d.keys().value(1) as usize));
1485    }
1486
1487    #[test]
1488    fn test_filter_list_array() {
1489        let field = Arc::new(Field::new_list_field(DataType::Int32, false));
1490        let offsets = OffsetBuffer::new(vec![0i64, 3, 6, 8, 8].into());
1491        let value_array = Arc::new(Int32Array::from_iter_values(0..8));
1492        let nulls = Some(NullBuffer::from(vec![true, true, true, false]));
1493        //  a = [[0, 1, 2], [3, 4, 5], [6, 7], null]
1494        let a = LargeListArray::new(field.clone(), offsets, value_array, nulls);
1495        let b = BooleanArray::from(vec![false, true, false, true]);
1496        let result = filter(&a, &b).unwrap();
1497
1498        // expected: [[3, 4, 5], null]
1499        let offsets = OffsetBuffer::new(vec![0i64, 3, 3].into());
1500        let value_array = Arc::new(Int32Array::from_iter_values([3, 4, 5]));
1501        let nulls = Some(NullBuffer::from(vec![true, false]));
1502        let expected: ArrayRef = Arc::new(LargeListArray::new(field, offsets, value_array, nulls));
1503
1504        assert_eq!(&expected, &result);
1505    }
1506
1507    fn test_case_filter_list_view<T: OffsetSizeTrait>() {
1508        // [[1, 2], null, [], [3,4]]
1509        let mut list_array = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1510        list_array.append_value([Some(1), Some(2)]);
1511        list_array.append_null();
1512        list_array.append_value([]);
1513        list_array.append_value([Some(3), Some(4)]);
1514
1515        let list_array = list_array.finish();
1516        let predicate = BooleanArray::from_iter([true, false, true, false]);
1517
1518        // Filter result: [[1, 2], []]
1519        let filtered = filter(&list_array, &predicate)
1520            .unwrap()
1521            .as_list_view::<T>()
1522            .clone();
1523
1524        let mut expected =
1525            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(5), 3);
1526        expected.append_value([Some(1), Some(2)]);
1527        expected.append_value([]);
1528        let expected = expected.finish();
1529
1530        assert_eq!(&filtered, &expected);
1531    }
1532
1533    fn test_case_filter_sliced_list_view<T: OffsetSizeTrait>() {
1534        // [[1, 2], null, [], [3,4]]
1535        let mut list_array =
1536            GenericListViewBuilder::<T, _>::with_capacity(Int32Builder::with_capacity(6), 4);
1537        list_array.append_value([Some(1), Some(2)]);
1538        list_array.append_null();
1539        list_array.append_value([]);
1540        list_array.append_value([Some(3), Some(4)]);
1541
1542        let list_array = list_array.finish();
1543
1544        // Sliced: [null, [], [3, 4]]
1545        let sliced = list_array.slice(1, 3);
1546        let predicate = BooleanArray::from_iter([false, false, true]);
1547
1548        // Filter result: [[1, 2], []]
1549        let filtered = filter(&sliced, &predicate)
1550            .unwrap()
1551            .as_list_view::<T>()
1552            .clone();
1553
1554        let mut expected = GenericListViewBuilder::<T, _>::new(Int32Builder::new());
1555        expected.append_value([Some(3), Some(4)]);
1556        let expected = expected.finish();
1557
1558        assert_eq!(&filtered, &expected);
1559    }
1560
1561    #[test]
1562    fn test_filter_list_view_array() {
1563        test_case_filter_list_view::<i32>();
1564        test_case_filter_list_view::<i64>();
1565
1566        test_case_filter_sliced_list_view::<i32>();
1567        test_case_filter_sliced_list_view::<i64>();
1568    }
1569
1570    #[test]
1571    fn test_slice_iterator_bits() {
1572        let filter_values = (0..64).map(|i| i == 1).collect::<Vec<bool>>();
1573        let filter = BooleanArray::from(filter_values);
1574        let filter_count = filter_count(&filter);
1575
1576        let iter = SlicesIterator::new(&filter);
1577        let chunks = iter.collect::<Vec<_>>();
1578
1579        assert_eq!(chunks, vec![(1, 2)]);
1580        assert_eq!(filter_count, 1);
1581    }
1582
1583    #[test]
1584    fn test_slice_iterator_bits1() {
1585        let filter_values = (0..64).map(|i| i != 1).collect::<Vec<bool>>();
1586        let filter = BooleanArray::from(filter_values);
1587        let filter_count = filter_count(&filter);
1588
1589        let iter = SlicesIterator::new(&filter);
1590        let chunks = iter.collect::<Vec<_>>();
1591
1592        assert_eq!(chunks, vec![(0, 1), (2, 64)]);
1593        assert_eq!(filter_count, 64 - 1);
1594    }
1595
1596    #[test]
1597    fn test_slice_iterator_chunk_and_bits() {
1598        let filter_values = (0..130).map(|i| i % 62 != 0).collect::<Vec<bool>>();
1599        let filter = BooleanArray::from(filter_values);
1600        let filter_count = filter_count(&filter);
1601
1602        let iter = SlicesIterator::new(&filter);
1603        let chunks = iter.collect::<Vec<_>>();
1604
1605        assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]);
1606        assert_eq!(filter_count, 61 + 61 + 5);
1607    }
1608
1609    #[test]
1610    fn test_null_mask() {
1611        let a = Int64Array::from(vec![Some(1), Some(2), None]);
1612
1613        let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]);
1614        let out = filter(&a, &mask1).unwrap();
1615        assert_eq!(out.as_ref(), &a.slice(0, 2));
1616    }
1617
1618    #[test]
1619    fn test_filter_record_batch_no_columns() {
1620        let pred = BooleanArray::from(vec![Some(true), Some(true), None]);
1621        let options = RecordBatchOptions::default().with_row_count(Some(100));
1622        let record_batch =
1623            RecordBatch::try_new_with_options(Arc::new(Schema::empty()), vec![], &options).unwrap();
1624        let out = filter_record_batch(&record_batch, &pred).unwrap();
1625
1626        assert_eq!(out.num_rows(), 2);
1627    }
1628
1629    #[test]
1630    fn test_fast_path() {
1631        let a: PrimitiveArray<Int64Type> = PrimitiveArray::from(vec![Some(1), Some(2), None]);
1632
1633        // all true
1634        let mask = BooleanArray::from(vec![true, true, true]);
1635        let out = filter(&a, &mask).unwrap();
1636        let b = out
1637            .as_any()
1638            .downcast_ref::<PrimitiveArray<Int64Type>>()
1639            .unwrap();
1640        assert_eq!(&a, b);
1641
1642        // all false
1643        let mask = BooleanArray::from(vec![false, false, false]);
1644        let out = filter(&a, &mask).unwrap();
1645        assert_eq!(out.len(), 0);
1646        assert_eq!(out.data_type(), &DataType::Int64);
1647    }
1648
1649    #[test]
1650    fn test_slices() {
1651        // takes up 2 u64s
1652        let bools = std::iter::repeat_n(true, 10)
1653            .chain(std::iter::repeat_n(false, 30))
1654            .chain(std::iter::repeat_n(true, 20))
1655            .chain(std::iter::repeat_n(false, 17))
1656            .chain(std::iter::repeat_n(true, 4));
1657
1658        let bool_array: BooleanArray = bools.map(Some).collect();
1659
1660        let slices: Vec<_> = SlicesIterator::new(&bool_array).collect();
1661        let expected = vec![(0, 10), (40, 60), (77, 81)];
1662        assert_eq!(slices, expected);
1663
1664        // slice with offset and truncated len
1665        let len = bool_array.len();
1666        let sliced_array = bool_array.slice(7, len - 10);
1667        let sliced_array = sliced_array
1668            .as_any()
1669            .downcast_ref::<BooleanArray>()
1670            .unwrap();
1671        let slices: Vec<_> = SlicesIterator::new(sliced_array).collect();
1672        let expected = vec![(0, 3), (33, 53), (70, 71)];
1673        assert_eq!(slices, expected);
1674    }
1675
1676    fn test_slices_fuzz(mask_len: usize, offset: usize, truncate: usize) {
1677        let mut rng = rng();
1678
1679        let bools: Vec<bool> = std::iter::from_fn(|| Some(rng.random()))
1680            .take(mask_len)
1681            .collect();
1682
1683        let buffer = Buffer::from_iter(bools.iter().cloned());
1684
1685        let truncated_length = mask_len - offset - truncate;
1686
1687        let filter = BooleanArray::new(BooleanBuffer::new(buffer, offset, truncated_length), None);
1688
1689        let slice_bits: Vec<_> = SlicesIterator::new(&filter)
1690            .flat_map(|(start, end)| start..end)
1691            .collect();
1692
1693        let count = filter_count(&filter);
1694        let index_bits: Vec<_> = IndexIterator::new(&filter, count).collect();
1695
1696        let expected_bits: Vec<_> = bools
1697            .iter()
1698            .skip(offset)
1699            .take(truncated_length)
1700            .enumerate()
1701            .flat_map(|(idx, v)| v.then(|| idx))
1702            .collect();
1703
1704        assert_eq!(slice_bits, expected_bits);
1705        assert_eq!(index_bits, expected_bits);
1706    }
1707
1708    #[test]
1709    #[cfg_attr(miri, ignore)]
1710    fn fuzz_test_slices_iterator() {
1711        let mut rng = rng();
1712
1713        let uusize = UniformUsize::new(usize::MIN, usize::MAX).unwrap();
1714        for _ in 0..100 {
1715            let mask_len = rng.random_range(0..1024);
1716            let max_offset = 64.min(mask_len);
1717            let offset = uusize.sample(&mut rng).checked_rem(max_offset).unwrap_or(0);
1718
1719            let max_truncate = 128.min(mask_len - offset);
1720            let truncate = uusize
1721                .sample(&mut rng)
1722                .checked_rem(max_truncate)
1723                .unwrap_or(0);
1724
1725            test_slices_fuzz(mask_len, offset, truncate);
1726        }
1727
1728        test_slices_fuzz(64, 0, 0);
1729        test_slices_fuzz(64, 8, 0);
1730        test_slices_fuzz(64, 8, 8);
1731        test_slices_fuzz(32, 8, 8);
1732        test_slices_fuzz(32, 5, 9);
1733    }
1734
1735    /// Filters `values` by `predicate` using standard rust iterators
1736    fn filter_rust<T>(values: impl IntoIterator<Item = T>, predicate: &[bool]) -> Vec<T> {
1737        values
1738            .into_iter()
1739            .zip(predicate)
1740            .filter(|(_, x)| **x)
1741            .map(|(a, _)| a)
1742            .collect()
1743    }
1744
1745    /// Generates an array of length `len` with `valid_percent` non-null values
1746    fn gen_primitive<T>(len: usize, valid_percent: f64) -> Vec<Option<T>>
1747    where
1748        StandardUniform: Distribution<T>,
1749    {
1750        let mut rng = rng();
1751        (0..len)
1752            .map(|_| rng.random_bool(valid_percent).then(|| rng.random()))
1753            .collect()
1754    }
1755
1756    /// Generates an array of length `len` with `valid_percent` non-null values
1757    fn gen_strings(
1758        len: usize,
1759        valid_percent: f64,
1760        str_len_range: std::ops::Range<usize>,
1761    ) -> Vec<Option<String>> {
1762        let mut rng = rng();
1763        (0..len)
1764            .map(|_| {
1765                rng.random_bool(valid_percent).then(|| {
1766                    let len = rng.random_range(str_len_range.clone());
1767                    (0..len)
1768                        .map(|_| char::from(rng.sample(Alphanumeric)))
1769                        .collect()
1770                })
1771            })
1772            .collect()
1773    }
1774
1775    /// Returns an iterator that calls `Option::as_deref` on each item
1776    fn as_deref<T: std::ops::Deref>(src: &[Option<T>]) -> impl Iterator<Item = Option<&T::Target>> {
1777        src.iter().map(|x| x.as_deref())
1778    }
1779
1780    #[test]
1781    #[cfg_attr(miri, ignore)]
1782    fn fuzz_filter() {
1783        let mut rng = rng();
1784
1785        for i in 0..100 {
1786            let filter_percent = match i {
1787                0..=4 => 1.,
1788                5..=10 => 0.,
1789                _ => rng.random_range(0.0..1.0),
1790            };
1791
1792            let valid_percent = rng.random_range(0.0..1.0);
1793
1794            let array_len = rng.random_range(32..256);
1795            let array_offset = rng.random_range(0..10);
1796
1797            // Construct a predicate
1798            let filter_offset = rng.random_range(0..10);
1799            let filter_truncate = rng.random_range(0..10);
1800            let bools: Vec<_> = std::iter::from_fn(|| Some(rng.random_bool(filter_percent)))
1801                .take(array_len + filter_offset - filter_truncate)
1802                .collect();
1803
1804            let predicate = BooleanArray::from_iter(bools.iter().cloned().map(Some));
1805
1806            // Offset predicate
1807            let predicate = predicate.slice(filter_offset, array_len - filter_truncate);
1808            let predicate = predicate.as_any().downcast_ref::<BooleanArray>().unwrap();
1809            let bools = &bools[filter_offset..];
1810
1811            // Test i32
1812            let values = gen_primitive(array_len + array_offset, valid_percent);
1813            let src = Int32Array::from_iter(values.iter().cloned());
1814
1815            let src = src.slice(array_offset, array_len);
1816            let src = src.as_any().downcast_ref::<Int32Array>().unwrap();
1817            let values = &values[array_offset..];
1818
1819            let filtered = filter(src, predicate).unwrap();
1820            let array = filtered.as_any().downcast_ref::<Int32Array>().unwrap();
1821            let actual: Vec<_> = array.iter().collect();
1822
1823            assert_eq!(actual, filter_rust(values.iter().cloned(), bools));
1824
1825            // Test string
1826            let strings = gen_strings(array_len + array_offset, valid_percent, 0..20);
1827            let src = StringArray::from_iter(as_deref(&strings));
1828
1829            let src = src.slice(array_offset, array_len);
1830            let src = src.as_any().downcast_ref::<StringArray>().unwrap();
1831
1832            let filtered = filter(src, predicate).unwrap();
1833            let array = filtered.as_any().downcast_ref::<StringArray>().unwrap();
1834            let actual: Vec<_> = array.iter().collect();
1835
1836            let expected_strings = filter_rust(as_deref(&strings[array_offset..]), bools);
1837            assert_eq!(actual, expected_strings);
1838
1839            // Test string dictionary
1840            let src = DictionaryArray::<Int32Type>::from_iter(as_deref(&strings));
1841
1842            let src = src.slice(array_offset, array_len);
1843            let src = src
1844                .as_any()
1845                .downcast_ref::<DictionaryArray<Int32Type>>()
1846                .unwrap();
1847
1848            let filtered = filter(src, predicate).unwrap();
1849
1850            let array = filtered
1851                .as_any()
1852                .downcast_ref::<DictionaryArray<Int32Type>>()
1853                .unwrap();
1854
1855            let values = array
1856                .values()
1857                .as_any()
1858                .downcast_ref::<StringArray>()
1859                .unwrap();
1860
1861            let actual: Vec<_> = array
1862                .keys()
1863                .iter()
1864                .map(|key| key.map(|key| values.value(key as usize)))
1865                .collect();
1866
1867            assert_eq!(actual, expected_strings);
1868        }
1869    }
1870
1871    #[test]
1872    fn test_filter_map() {
1873        let mut builder =
1874            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(4));
1875        // [{"key1": 1}, {"key2": 2, "key3": 3}, null, {"key1": 1}
1876        builder.keys().append_value("key1");
1877        builder.values().append_value(1);
1878        builder.append(true).unwrap();
1879        builder.keys().append_value("key2");
1880        builder.keys().append_value("key3");
1881        builder.values().append_value(2);
1882        builder.values().append_value(3);
1883        builder.append(true).unwrap();
1884        builder.append(false).unwrap();
1885        builder.keys().append_value("key1");
1886        builder.values().append_value(1);
1887        builder.append(true).unwrap();
1888        let maparray = Arc::new(builder.finish()) as ArrayRef;
1889
1890        let indices = vec![Some(true), Some(false), Some(false), Some(true)]
1891            .into_iter()
1892            .collect::<BooleanArray>();
1893        let got = filter(&maparray, &indices).unwrap();
1894
1895        let mut builder =
1896            MapBuilder::new(None, StringBuilder::new(), Int64Builder::with_capacity(2));
1897        builder.keys().append_value("key1");
1898        builder.values().append_value(1);
1899        builder.append(true).unwrap();
1900        builder.keys().append_value("key1");
1901        builder.values().append_value(1);
1902        builder.append(true).unwrap();
1903        let expected = Arc::new(builder.finish()) as ArrayRef;
1904
1905        assert_eq!(&expected, &got);
1906    }
1907
1908    #[test]
1909    fn test_filter_fixed_size_list_arrays() {
1910        let field = Arc::new(Field::new_list_field(DataType::Int32, false));
1911        let value_array = Arc::new(Int32Array::from_iter_values(0..9));
1912        let array = FixedSizeListArray::new(field, 3, value_array, None);
1913
1914        let filter_array = BooleanArray::from(vec![true, false, false]);
1915
1916        let c = filter(&array, &filter_array).unwrap();
1917        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1918
1919        assert_eq!(filtered.len(), 1);
1920
1921        let list = filtered.value(0);
1922        assert_eq!(
1923            &[0, 1, 2],
1924            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1925        );
1926
1927        let filter_array = BooleanArray::from(vec![true, false, true]);
1928
1929        let c = filter(&array, &filter_array).unwrap();
1930        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1931
1932        assert_eq!(filtered.len(), 2);
1933
1934        let list = filtered.value(0);
1935        assert_eq!(
1936            &[0, 1, 2],
1937            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1938        );
1939        let list = filtered.value(1);
1940        assert_eq!(
1941            &[6, 7, 8],
1942            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1943        );
1944    }
1945
1946    #[test]
1947    fn test_filter_fixed_size_list_arrays_with_null() {
1948        let field = Arc::new(Field::new_list_field(DataType::Int32, false));
1949        let value_array = Arc::new(Int32Array::from_iter_values(0..10));
1950        let nulls = Some(NullBuffer::from(vec![true, false, false, true, true]));
1951        let array = FixedSizeListArray::new(field, 2, value_array, nulls);
1952
1953        let filter_array = BooleanArray::from(vec![true, true, false, true, false]);
1954
1955        let c = filter(&array, &filter_array).unwrap();
1956        let filtered = c.as_any().downcast_ref::<FixedSizeListArray>().unwrap();
1957
1958        assert_eq!(filtered.len(), 3);
1959
1960        let list = filtered.value(0);
1961        assert_eq!(
1962            &[0, 1],
1963            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1964        );
1965        assert!(filtered.is_null(1));
1966        let list = filtered.value(2);
1967        assert_eq!(
1968            &[6, 7],
1969            list.as_any().downcast_ref::<Int32Array>().unwrap().values()
1970        );
1971    }
1972
1973    fn test_filter_union_array(array: UnionArray) {
1974        let filter_array = BooleanArray::from(vec![true, false, false]);
1975        let c = filter(&array, &filter_array).unwrap();
1976        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1977
1978        let mut builder = UnionBuilder::new_dense();
1979        builder.append::<Int32Type>("A", 1).unwrap();
1980        let expected_array = builder.build().unwrap();
1981
1982        compare_union_arrays(filtered, &expected_array);
1983
1984        let filter_array = BooleanArray::from(vec![true, false, true]);
1985        let c = filter(&array, &filter_array).unwrap();
1986        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1987
1988        let mut builder = UnionBuilder::new_dense();
1989        builder.append::<Int32Type>("A", 1).unwrap();
1990        builder.append::<Int32Type>("A", 34).unwrap();
1991        let expected_array = builder.build().unwrap();
1992
1993        compare_union_arrays(filtered, &expected_array);
1994
1995        let filter_array = BooleanArray::from(vec![true, true, false]);
1996        let c = filter(&array, &filter_array).unwrap();
1997        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
1998
1999        let mut builder = UnionBuilder::new_dense();
2000        builder.append::<Int32Type>("A", 1).unwrap();
2001        builder.append::<Float64Type>("B", 3.2).unwrap();
2002        let expected_array = builder.build().unwrap();
2003
2004        compare_union_arrays(filtered, &expected_array);
2005    }
2006
2007    #[test]
2008    fn test_filter_union_array_dense() {
2009        let mut builder = UnionBuilder::new_dense();
2010        builder.append::<Int32Type>("A", 1).unwrap();
2011        builder.append::<Float64Type>("B", 3.2).unwrap();
2012        builder.append::<Int32Type>("A", 34).unwrap();
2013        let array = builder.build().unwrap();
2014
2015        test_filter_union_array(array);
2016    }
2017
2018    #[test]
2019    fn test_filter_run_union_array_dense() {
2020        let mut builder = UnionBuilder::new_dense();
2021        builder.append::<Int32Type>("A", 1).unwrap();
2022        builder.append::<Int32Type>("A", 3).unwrap();
2023        builder.append::<Int32Type>("A", 34).unwrap();
2024        let array = builder.build().unwrap();
2025
2026        let filter_array = BooleanArray::from(vec![true, true, false]);
2027        let c = filter(&array, &filter_array).unwrap();
2028        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2029
2030        let mut builder = UnionBuilder::new_dense();
2031        builder.append::<Int32Type>("A", 1).unwrap();
2032        builder.append::<Int32Type>("A", 3).unwrap();
2033        let expected = builder.build().unwrap();
2034
2035        assert_eq!(filtered.to_data(), expected.to_data());
2036    }
2037
2038    #[test]
2039    fn test_filter_union_array_dense_with_nulls() {
2040        let mut builder = UnionBuilder::new_dense();
2041        builder.append::<Int32Type>("A", 1).unwrap();
2042        builder.append::<Float64Type>("B", 3.2).unwrap();
2043        builder.append_null::<Float64Type>("B").unwrap();
2044        builder.append::<Int32Type>("A", 34).unwrap();
2045        let array = builder.build().unwrap();
2046
2047        let filter_array = BooleanArray::from(vec![true, true, false, false]);
2048        let c = filter(&array, &filter_array).unwrap();
2049        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2050
2051        let mut builder = UnionBuilder::new_dense();
2052        builder.append::<Int32Type>("A", 1).unwrap();
2053        builder.append::<Float64Type>("B", 3.2).unwrap();
2054        let expected_array = builder.build().unwrap();
2055
2056        compare_union_arrays(filtered, &expected_array);
2057
2058        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2059        let c = filter(&array, &filter_array).unwrap();
2060        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2061
2062        let mut builder = UnionBuilder::new_dense();
2063        builder.append::<Int32Type>("A", 1).unwrap();
2064        builder.append_null::<Float64Type>("B").unwrap();
2065        let expected_array = builder.build().unwrap();
2066
2067        compare_union_arrays(filtered, &expected_array);
2068    }
2069
2070    #[test]
2071    fn test_filter_union_array_sparse() {
2072        let mut builder = UnionBuilder::new_sparse();
2073        builder.append::<Int32Type>("A", 1).unwrap();
2074        builder.append::<Float64Type>("B", 3.2).unwrap();
2075        builder.append::<Int32Type>("A", 34).unwrap();
2076        let array = builder.build().unwrap();
2077
2078        test_filter_union_array(array);
2079    }
2080
2081    #[test]
2082    fn test_filter_union_array_sparse_with_nulls() {
2083        let mut builder = UnionBuilder::new_sparse();
2084        builder.append::<Int32Type>("A", 1).unwrap();
2085        builder.append::<Float64Type>("B", 3.2).unwrap();
2086        builder.append_null::<Float64Type>("B").unwrap();
2087        builder.append::<Int32Type>("A", 34).unwrap();
2088        let array = builder.build().unwrap();
2089
2090        let filter_array = BooleanArray::from(vec![true, false, true, false]);
2091        let c = filter(&array, &filter_array).unwrap();
2092        let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
2093
2094        let mut builder = UnionBuilder::new_sparse();
2095        builder.append::<Int32Type>("A", 1).unwrap();
2096        builder.append_null::<Float64Type>("B").unwrap();
2097        let expected_array = builder.build().unwrap();
2098
2099        compare_union_arrays(filtered, &expected_array);
2100    }
2101
2102    fn compare_union_arrays(union1: &UnionArray, union2: &UnionArray) {
2103        assert_eq!(union1.len(), union2.len());
2104
2105        for i in 0..union1.len() {
2106            let type_id = union1.type_id(i);
2107
2108            let slot1 = union1.value(i);
2109            let slot2 = union2.value(i);
2110
2111            assert_eq!(slot1.is_null(0), slot2.is_null(0));
2112
2113            if !slot1.is_null(0) && !slot2.is_null(0) {
2114                match type_id {
2115                    0 => {
2116                        let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
2117                        assert_eq!(slot1.len(), 1);
2118                        let value1 = slot1.value(0);
2119
2120                        let slot2 = slot2.as_any().downcast_ref::<Int32Array>().unwrap();
2121                        assert_eq!(slot2.len(), 1);
2122                        let value2 = slot2.value(0);
2123                        assert_eq!(value1, value2);
2124                    }
2125                    1 => {
2126                        let slot1 = slot1.as_any().downcast_ref::<Float64Array>().unwrap();
2127                        assert_eq!(slot1.len(), 1);
2128                        let value1 = slot1.value(0);
2129
2130                        let slot2 = slot2.as_any().downcast_ref::<Float64Array>().unwrap();
2131                        assert_eq!(slot2.len(), 1);
2132                        let value2 = slot2.value(0);
2133                        assert_eq!(value1, value2);
2134                    }
2135                    _ => unreachable!(),
2136                }
2137            }
2138        }
2139    }
2140
2141    #[test]
2142    fn test_filter_struct() {
2143        let predicate = BooleanArray::from(vec![true, false, true, false]);
2144
2145        let a = Arc::new(StringArray::from(vec!["hello", " ", "world", "!"]));
2146        let a_filtered = Arc::new(StringArray::from(vec!["hello", "world"]));
2147
2148        let b = Arc::new(Int32Array::from(vec![5, 6, 7, 8]));
2149        let b_filtered = Arc::new(Int32Array::from(vec![5, 7]));
2150
2151        let null_mask = NullBuffer::from(vec![true, false, false, true]);
2152        let null_mask_filtered = NullBuffer::from(vec![true, false]);
2153
2154        let a_field = Field::new("a", DataType::Utf8, false);
2155        let b_field = Field::new("b", DataType::Int32, false);
2156
2157        let array = StructArray::new(vec![a_field.clone()].into(), vec![a.clone()], None);
2158        let expected =
2159            StructArray::new(vec![a_field.clone()].into(), vec![a_filtered.clone()], None);
2160
2161        let result = filter(&array, &predicate).unwrap();
2162
2163        assert_eq!(result.to_data(), expected.to_data());
2164
2165        let array = StructArray::new(
2166            vec![a_field.clone()].into(),
2167            vec![a.clone()],
2168            Some(null_mask.clone()),
2169        );
2170        let expected = StructArray::new(
2171            vec![a_field.clone()].into(),
2172            vec![a_filtered.clone()],
2173            Some(null_mask_filtered.clone()),
2174        );
2175
2176        let result = filter(&array, &predicate).unwrap();
2177
2178        assert_eq!(result.to_data(), expected.to_data());
2179
2180        let array = StructArray::new(
2181            vec![a_field.clone(), b_field.clone()].into(),
2182            vec![a.clone(), b.clone()],
2183            None,
2184        );
2185        let expected = StructArray::new(
2186            vec![a_field.clone(), b_field.clone()].into(),
2187            vec![a_filtered.clone(), b_filtered.clone()],
2188            None,
2189        );
2190
2191        let result = filter(&array, &predicate).unwrap();
2192
2193        assert_eq!(result.to_data(), expected.to_data());
2194
2195        let array = StructArray::new(
2196            vec![a_field.clone(), b_field.clone()].into(),
2197            vec![a.clone(), b.clone()],
2198            Some(null_mask.clone()),
2199        );
2200
2201        let expected = StructArray::new(
2202            vec![a_field.clone(), b_field.clone()].into(),
2203            vec![a_filtered.clone(), b_filtered.clone()],
2204            Some(null_mask_filtered.clone()),
2205        );
2206
2207        let result = filter(&array, &predicate).unwrap();
2208
2209        assert_eq!(result.to_data(), expected.to_data());
2210    }
2211
2212    #[test]
2213    fn test_filter_empty_struct() {
2214        /*
2215            "a": {
2216                "b": int64,
2217                "c": {}
2218            },
2219        */
2220        let fields = arrow_schema::Field::new(
2221            "a",
2222            arrow_schema::DataType::Struct(arrow_schema::Fields::from(vec![
2223                arrow_schema::Field::new("b", arrow_schema::DataType::Int64, true),
2224                arrow_schema::Field::new(
2225                    "c",
2226                    arrow_schema::DataType::Struct(arrow_schema::Fields::empty()),
2227                    true,
2228                ),
2229            ])),
2230            true,
2231        );
2232
2233        /* Test record
2234            {"a":{"c": {}}}
2235            {"a":{"c": {}}}
2236            {"a":{"c": {}}}
2237        */
2238
2239        // Create the record batch with the nested struct array
2240        let schema = Arc::new(Schema::new(vec![fields]));
2241
2242        let b = Arc::new(Int64Array::from(vec![None, None, None]));
2243        let c = Arc::new(StructArray::new_empty_fields(
2244            3,
2245            Some(NullBuffer::from(vec![true, true, true])),
2246        ));
2247        let a = StructArray::new(
2248            vec![
2249                Field::new("b", DataType::Int64, true),
2250                Field::new("c", DataType::Struct(Fields::empty()), true),
2251            ]
2252            .into(),
2253            vec![b.clone(), c.clone()],
2254            Some(NullBuffer::from(vec![true, true, true])),
2255        );
2256        let record_batch = RecordBatch::try_new(schema, vec![Arc::new(a)]).unwrap();
2257        println!("{record_batch:?}");
2258
2259        // Apply the filter
2260        let predicate = BooleanArray::from(vec![true, false, true]);
2261        let filtered_batch = filter_record_batch(&record_batch, &predicate).unwrap();
2262
2263        // The filtered batch should have 2 rows (the 1st and 3rd)
2264        assert_eq!(filtered_batch.num_rows(), 2);
2265    }
2266
2267    #[test]
2268    #[should_panic]
2269    fn test_filter_bits_too_large() {
2270        let buffer = BooleanBuffer::from(vec![false; 8]);
2271        let predicate = BooleanArray::from(vec![true; 9]);
2272        let filter = FilterBuilder::new(&predicate).build();
2273        filter_bits(&buffer, &filter);
2274    }
2275
2276    #[test]
2277    #[should_panic]
2278    fn test_filter_native_too_large() {
2279        let values = vec![1; 8];
2280        let predicate = BooleanArray::from(vec![false; 9]);
2281        let filter = FilterBuilder::new(&predicate).build();
2282        filter_native(&values, &filter);
2283    }
2284}