1use arrow::array::{Array, BooleanArray};
24use std::cmp::Ordering;
25use std::ops::Range;
26
/// A run of `row_count` consecutive rows that are either all selected or all skipped.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RowSelector {
    /// Number of consecutive rows covered by this selector.
    pub row_count: usize,

    /// `true` if the rows are skipped, `false` if they are selected.
    pub skip: bool,
}

impl RowSelector {
    /// Returns a selector that selects `row_count` rows.
    pub fn select(row_count: usize) -> Self {
        Self::run(row_count, false)
    }

    /// Returns a selector that skips `row_count` rows.
    pub fn skip(row_count: usize) -> Self {
        Self::run(row_count, true)
    }

    /// Shared constructor behind [`Self::select`] and [`Self::skip`].
    fn run(row_count: usize, skip: bool) -> Self {
        RowSelector { row_count, skip }
    }
}
57
58#[derive(Debug, Clone, Default, Eq, PartialEq)]
90pub struct RowSelection {
91    selectors: Vec<RowSelector>,
92}
93
94impl RowSelection {
95    pub fn new() -> Self {
97        Self::default()
98    }
99
100    pub fn from_filters(filters: &[BooleanArray]) -> Self {
106        let mut next_offset = 0;
107        let total_rows = filters.iter().map(|x| x.len()).sum();
108
109        let iter = filters.iter().flat_map(|filter| {
110            let offset = next_offset;
111            next_offset += filter.len();
112            assert_eq!(
113                filter.null_count(),
114                0,
115                "filter arrays must not contain nulls"
116            );
117
118            let mut ranges = vec![];
120            let mut start = None;
121            for (idx, value) in filter.iter().enumerate() {
122                match (value, start) {
123                    (Some(true), None) => start = Some(idx),
124                    (Some(false), Some(s)) | (None, Some(s)) => {
125                        ranges.push(s + offset..idx + offset);
126                        start = None;
127                    }
128                    _ => {}
129                }
130            }
131            if let Some(s) = start {
132                ranges.push(s + offset..filter.len() + offset);
133            }
134            ranges
135        });
136
137        Self::from_consecutive_ranges(iter, total_rows)
138    }
139
140    pub fn from_consecutive_ranges<I: Iterator<Item = Range<usize>>>(
159        ranges: I,
160        total_rows: usize,
161    ) -> Self {
162        let mut selectors: Vec<RowSelector> = Vec::with_capacity(ranges.size_hint().0);
163        let mut last_end = 0;
164
165        for range in ranges {
166            let len = range.end - range.start;
167            if len == 0 {
168                continue;
169            }
170
171            match range.start.cmp(&last_end) {
172                Ordering::Equal => {
173                    match selectors.last_mut() {
175                        Some(last) if !last.skip => {
176                            last.row_count = last.row_count.checked_add(len).unwrap()
177                        }
178                        _ => selectors.push(RowSelector::select(len)),
179                    }
180                }
181                Ordering::Greater => {
182                    selectors.push(RowSelector::skip(range.start - last_end));
184                    selectors.push(RowSelector::select(len));
185                }
186                Ordering::Less => {
187                    panic!("ranges must be provided in order and must not overlap")
188                }
189            }
190            last_end = range.end;
191        }
192
193        if last_end < total_rows {
195            selectors.push(RowSelector::skip(total_rows - last_end));
196        }
197
198        Self { selectors }
199    }
200
201    pub fn select_all(row_count: usize) -> Self {
203        if row_count == 0 {
204            return Self::default();
205        }
206        Self {
207            selectors: vec![RowSelector::select(row_count)],
208        }
209    }
210
211    pub fn skip_all(row_count: usize) -> Self {
213        if row_count == 0 {
214            return Self::default();
215        }
216        Self {
217            selectors: vec![RowSelector::skip(row_count)],
218        }
219    }
220
221    pub fn row_count(&self) -> usize {
223        self.selectors.iter().map(|s| s.row_count).sum()
224    }
225
226    pub fn selected_row_count(&self) -> usize {
228        self.selectors
229            .iter()
230            .filter(|s| !s.skip)
231            .map(|s| s.row_count)
232            .sum()
233    }
234
235    pub fn skipped_row_count(&self) -> usize {
237        self.selectors
238            .iter()
239            .filter(|s| s.skip)
240            .map(|s| s.row_count)
241            .sum()
242    }
243
244    pub fn selects_any(&self) -> bool {
246        self.selectors.iter().any(|s| !s.skip)
247    }
248
249    pub fn iter(&self) -> impl Iterator<Item = &RowSelector> {
251        self.selectors.iter()
252    }
253
254    pub fn selectors(&self) -> &[RowSelector] {
256        &self.selectors
257    }
258
259    pub fn split_off(&mut self, row_count: usize) -> Self {
279        let mut total_count = 0;
280
281        let find = self.selectors.iter().position(|selector| {
283            total_count += selector.row_count;
284            total_count > row_count
285        });
286
287        let split_idx = match find {
288            Some(idx) => idx,
289            None => {
290                let selectors = std::mem::take(&mut self.selectors);
292                return Self { selectors };
293            }
294        };
295
296        let mut remaining = self.selectors.split_off(split_idx);
297
298        let next = remaining.first_mut().unwrap();
300        let overflow = total_count - row_count;
301
302        if next.row_count != overflow {
303            self.selectors.push(RowSelector {
304                row_count: next.row_count - overflow,
305                skip: next.skip,
306            });
307        }
308        next.row_count = overflow;
309
310        std::mem::swap(&mut remaining, &mut self.selectors);
311        Self {
312            selectors: remaining,
313        }
314    }
315
316    pub fn and_then(&self, other: &Self) -> Self {
326        let mut selectors = vec![];
327        let mut first = self.selectors.iter().cloned().peekable();
328        let mut second = other.selectors.iter().cloned().peekable();
329
330        let mut to_skip = 0;
331        while let Some(b) = second.peek_mut() {
332            let a = first
333                .peek_mut()
334                .expect("selection exceeds the number of selected rows");
335
336            if b.row_count == 0 {
337                second.next().unwrap();
338                continue;
339            }
340
341            if a.row_count == 0 {
342                first.next().unwrap();
343                continue;
344            }
345
346            if a.skip {
347                to_skip += a.row_count;
349                first.next().unwrap();
350                continue;
351            }
352
353            let skip = b.skip;
354            let to_process = a.row_count.min(b.row_count);
355
356            a.row_count -= to_process;
357            b.row_count -= to_process;
358
359            match skip {
360                true => to_skip += to_process,
361                false => {
362                    if to_skip != 0 {
363                        selectors.push(RowSelector::skip(to_skip));
364                        to_skip = 0;
365                    }
366                    selectors.push(RowSelector::select(to_process));
367                }
368            }
369        }
370
371        for v in first {
373            if v.row_count != 0 {
374                assert!(
375                    v.skip,
376                    "selection contains less than the number of selected rows"
377                );
378                to_skip += v.row_count;
379            }
380        }
381
382        if to_skip != 0 {
383            selectors.push(RowSelector::skip(to_skip));
384        }
385
386        Self { selectors }
387    }
388}
389
390impl From<Vec<RowSelector>> for RowSelection {
391    fn from(selectors: Vec<RowSelector>) -> Self {
392        let mut result: Vec<RowSelector> = Vec::new();
393        for selector in selectors {
394            if selector.row_count == 0 {
395                continue;
396            }
397            match result.last_mut() {
398                Some(last) if last.skip == selector.skip => {
399                    last.row_count += selector.row_count;
400                }
401                _ => result.push(selector),
402            }
403        }
404        Self { selectors: result }
405    }
406}
407
408impl From<RowSelection> for Vec<RowSelector> {
409    fn from(selection: RowSelection) -> Self {
410        selection.selectors
411    }
412}
413
414impl FromIterator<RowSelector> for RowSelection {
415    fn from_iter<T: IntoIterator<Item = RowSelector>>(iter: T) -> Self {
416        iter.into_iter().collect::<Vec<_>>().into()
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_row_selector_select() {
        let selector = RowSelector::select(100);
        assert_eq!(selector.row_count, 100);
        assert!(!selector.skip);
    }

    #[test]
    fn test_row_selector_skip() {
        let selector = RowSelector::skip(50);
        assert_eq!(selector.row_count, 50);
        assert!(selector.skip);
    }

    #[test]
    fn test_row_selection_from_consecutive_ranges() {
        let selection = RowSelection::from_consecutive_ranges(vec![5..10, 15..20].into_iter(), 25);

        let expected = vec![
            RowSelector::skip(5),
            RowSelector::select(5),
            RowSelector::skip(5),
            RowSelector::select(5),
            RowSelector::skip(5),
        ];

        assert_eq!(selection.selectors, expected);
        assert_eq!(selection.row_count(), 25);
        assert_eq!(selection.selected_row_count(), 10);
        assert_eq!(selection.skipped_row_count(), 15);
    }

    #[test]
    fn test_row_selection_from_consecutive_ranges_merges_adjacent() {
        // Touching ranges merge; empty ranges are ignored.
        let selection = RowSelection::from_consecutive_ranges(
            vec![0..5, 5..10, 12..12, 15..20].into_iter(),
            20,
        );

        let expected = vec![
            RowSelector::select(10),
            RowSelector::skip(5),
            RowSelector::select(5),
        ];

        assert_eq!(selection.selectors, expected);
    }

    #[test]
    fn test_row_selection_consolidation() {
        let selectors = vec![
            RowSelector::skip(5),
            RowSelector::skip(5),
            RowSelector::select(10),
            RowSelector::select(5),
        ];

        let selection: RowSelection = selectors.into();

        let expected = vec![RowSelector::skip(10), RowSelector::select(15)];

        assert_eq!(selection.selectors, expected);
    }

    #[test]
    fn test_row_selection_from_vec_drops_empty() {
        let selection: RowSelection = vec![
            RowSelector::select(0),
            RowSelector::skip(3),
            RowSelector::skip(0),
            RowSelector::skip(2),
        ]
        .into();

        assert_eq!(selection.selectors, vec![RowSelector::skip(5)]);
    }

    #[test]
    fn test_row_selection_from_iter_round_trip() {
        let selection: RowSelection = vec![RowSelector::skip(2), RowSelector::select(3)]
            .into_iter()
            .collect();
        let selectors: Vec<RowSelector> = selection.into();

        assert_eq!(
            selectors,
            vec![RowSelector::skip(2), RowSelector::select(3)]
        );
    }

    #[test]
    fn test_row_selection_select_all() {
        let selection = RowSelection::select_all(100);
        assert_eq!(selection.row_count(), 100);
        assert_eq!(selection.selected_row_count(), 100);
        assert_eq!(selection.skipped_row_count(), 0);
        assert!(selection.selects_any());
    }

    #[test]
    fn test_row_selection_skip_all() {
        let selection = RowSelection::skip_all(100);
        assert_eq!(selection.row_count(), 100);
        assert_eq!(selection.selected_row_count(), 0);
        assert_eq!(selection.skipped_row_count(), 100);
        assert!(!selection.selects_any());
    }

    #[test]
    fn test_row_selection_split_off() {
        let mut selection =
            RowSelection::from_consecutive_ranges(vec![10..30, 40..60].into_iter(), 100);

        let first = selection.split_off(35);

        assert_eq!(first.row_count(), 35);
        assert_eq!(selection.row_count(), 65);

        assert_eq!(first.selected_row_count(), 20);
        assert_eq!(selection.selected_row_count(), 20);

        // The skip run straddling row 35 is divided between the two halves.
        assert_eq!(
            first.selectors,
            vec![
                RowSelector::skip(10),
                RowSelector::select(20),
                RowSelector::skip(5),
            ]
        );
        assert_eq!(
            selection.selectors,
            vec![
                RowSelector::skip(5),
                RowSelector::select(20),
                RowSelector::skip(40),
            ]
        );
    }

    #[test]
    fn test_row_selection_split_off_past_end() {
        let mut selection = RowSelection::select_all(10);

        let first = selection.split_off(20);

        assert_eq!(first.selectors, vec![RowSelector::select(10)]);
        assert!(selection.selectors.is_empty());
    }

    #[test]
    fn test_row_selection_and_then() {
        let first = RowSelection::from_consecutive_ranges(std::iter::once(5..15), 20);

        let second = RowSelection::from_consecutive_ranges(std::iter::once(2..7), 10);

        let result = first.and_then(&second);

        assert_eq!(result.row_count(), 20);
        assert_eq!(result.selected_row_count(), 5);

        let expected = vec![
            RowSelector::skip(7),
            RowSelector::select(5),
            RowSelector::skip(8),
        ];
        assert_eq!(result.selectors, expected);
    }

    #[test]
    #[should_panic(expected = "selection exceeds the number of selected rows")]
    fn test_row_selection_and_then_too_many_rows() {
        let _ = RowSelection::select_all(5).and_then(&RowSelection::select_all(6));
    }

    #[test]
    #[should_panic(expected = "selection contains less than the number of selected rows")]
    fn test_row_selection_and_then_too_few_rows() {
        let _ = RowSelection::select_all(5).and_then(&RowSelection::select_all(3));
    }

    #[test]
    fn test_row_selection_from_filters() {
        use arrow::array::BooleanArray;

        let filter = BooleanArray::from(vec![false, false, true, true, false]);

        let selection = RowSelection::from_filters(&[filter]);

        let expected = vec![
            RowSelector::skip(2),
            RowSelector::select(2),
            RowSelector::skip(1),
        ];

        assert_eq!(selection.selectors, expected);
    }

    #[test]
    fn test_row_selection_from_filters_multiple_arrays() {
        // The run of `true` spanning the array boundary becomes a single selector.
        let filters = [
            BooleanArray::from(vec![false, true]),
            BooleanArray::from(vec![true, false]),
        ];

        let selection = RowSelection::from_filters(&filters);

        let expected = vec![
            RowSelector::skip(1),
            RowSelector::select(2),
            RowSelector::skip(1),
        ];

        assert_eq!(selection.selectors, expected);
        assert_eq!(selection.row_count(), 4);
    }

    #[test]
    #[should_panic(expected = "filter arrays must not contain nulls")]
    fn test_row_selection_from_filters_rejects_nulls() {
        let filter = BooleanArray::from(vec![Some(true), None]);
        RowSelection::from_filters(&[filter]);
    }

    #[test]
    fn test_row_selection_empty() {
        let selection = RowSelection::new();
        assert_eq!(selection.row_count(), 0);
        assert_eq!(selection.selected_row_count(), 0);
        assert!(!selection.selects_any());
    }

    #[test]
    #[should_panic(expected = "ranges must be provided in order")]
    fn test_row_selection_out_of_order() {
        RowSelection::from_consecutive_ranges(vec![10..20, 5..15].into_iter(), 25);
    }
}