1use arrow::array::{Array, BooleanArray};
24use std::cmp::Ordering;
25use std::ops::Range;
26
/// A run of consecutive rows that are either all selected or all skipped.
///
/// A [`RowSelection`] is an ordered sequence of these runs.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct RowSelector {
    /// Number of consecutive rows covered by this run.
    pub row_count: usize,

    /// `true` if the rows in this run are skipped, `false` if they are selected.
    pub skip: bool,
}
39
40impl RowSelector {
41 pub fn select(row_count: usize) -> Self {
43 Self {
44 row_count,
45 skip: false,
46 }
47 }
48
49 pub fn skip(row_count: usize) -> Self {
51 Self {
52 row_count,
53 skip: true,
54 }
55 }
56}
57
/// A selection of rows, stored as an ordered sequence of [`RowSelector`] runs.
///
/// Each run selects or skips the next `row_count` rows, in order. The total
/// number of rows the selection covers is the sum of the runs' `row_count`s.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct RowSelection {
    // The runs, in row order.
    selectors: Vec<RowSelector>,
}
93
94impl RowSelection {
95 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn from_filters(filters: &[BooleanArray]) -> Self {
106 let mut next_offset = 0;
107 let total_rows = filters.iter().map(|x| x.len()).sum();
108
109 let iter = filters.iter().flat_map(|filter| {
110 let offset = next_offset;
111 next_offset += filter.len();
112 assert_eq!(
113 filter.null_count(),
114 0,
115 "filter arrays must not contain nulls"
116 );
117
118 let mut ranges = vec![];
120 let mut start = None;
121 for (idx, value) in filter.iter().enumerate() {
122 match (value, start) {
123 (Some(true), None) => start = Some(idx),
124 (Some(false), Some(s)) | (None, Some(s)) => {
125 ranges.push(s + offset..idx + offset);
126 start = None;
127 }
128 _ => {}
129 }
130 }
131 if let Some(s) = start {
132 ranges.push(s + offset..filter.len() + offset);
133 }
134 ranges
135 });
136
137 Self::from_consecutive_ranges(iter, total_rows)
138 }
139
140 pub fn from_consecutive_ranges<I: Iterator<Item = Range<usize>>>(
159 ranges: I,
160 total_rows: usize,
161 ) -> Self {
162 let mut selectors: Vec<RowSelector> = Vec::with_capacity(ranges.size_hint().0);
163 let mut last_end = 0;
164
165 for range in ranges {
166 let len = range.end - range.start;
167 if len == 0 {
168 continue;
169 }
170
171 match range.start.cmp(&last_end) {
172 Ordering::Equal => {
173 match selectors.last_mut() {
175 Some(last) if !last.skip => {
176 last.row_count = last.row_count.checked_add(len).unwrap()
177 }
178 _ => selectors.push(RowSelector::select(len)),
179 }
180 }
181 Ordering::Greater => {
182 selectors.push(RowSelector::skip(range.start - last_end));
184 selectors.push(RowSelector::select(len));
185 }
186 Ordering::Less => {
187 panic!("ranges must be provided in order and must not overlap")
188 }
189 }
190 last_end = range.end;
191 }
192
193 if last_end < total_rows {
195 selectors.push(RowSelector::skip(total_rows - last_end));
196 }
197
198 Self { selectors }
199 }
200
201 pub fn select_all(row_count: usize) -> Self {
203 if row_count == 0 {
204 return Self::default();
205 }
206 Self {
207 selectors: vec![RowSelector::select(row_count)],
208 }
209 }
210
211 pub fn skip_all(row_count: usize) -> Self {
213 if row_count == 0 {
214 return Self::default();
215 }
216 Self {
217 selectors: vec![RowSelector::skip(row_count)],
218 }
219 }
220
221 pub fn row_count(&self) -> usize {
223 self.selectors.iter().map(|s| s.row_count).sum()
224 }
225
226 pub fn selected_row_count(&self) -> usize {
228 self.selectors
229 .iter()
230 .filter(|s| !s.skip)
231 .map(|s| s.row_count)
232 .sum()
233 }
234
235 pub fn skipped_row_count(&self) -> usize {
237 self.selectors
238 .iter()
239 .filter(|s| s.skip)
240 .map(|s| s.row_count)
241 .sum()
242 }
243
244 pub fn selects_any(&self) -> bool {
246 self.selectors.iter().any(|s| !s.skip)
247 }
248
249 pub fn iter(&self) -> impl Iterator<Item = &RowSelector> {
251 self.selectors.iter()
252 }
253
254 pub fn selectors(&self) -> &[RowSelector] {
256 &self.selectors
257 }
258
259 pub fn split_off(&mut self, row_count: usize) -> Self {
279 let mut total_count = 0;
280
281 let find = self.selectors.iter().position(|selector| {
283 total_count += selector.row_count;
284 total_count > row_count
285 });
286
287 let split_idx = match find {
288 Some(idx) => idx,
289 None => {
290 let selectors = std::mem::take(&mut self.selectors);
292 return Self { selectors };
293 }
294 };
295
296 let mut remaining = self.selectors.split_off(split_idx);
297
298 let next = remaining.first_mut().unwrap();
300 let overflow = total_count - row_count;
301
302 if next.row_count != overflow {
303 self.selectors.push(RowSelector {
304 row_count: next.row_count - overflow,
305 skip: next.skip,
306 });
307 }
308 next.row_count = overflow;
309
310 std::mem::swap(&mut remaining, &mut self.selectors);
311 Self {
312 selectors: remaining,
313 }
314 }
315
316 pub fn from_row_group_filter(
349 row_group_filter: &[bool],
350 rows_per_group: usize,
351 total_rows: usize,
352 ) -> Self {
353 if row_group_filter.is_empty() {
354 return Self::skip_all(total_rows);
355 }
356
357 let num_row_groups = row_group_filter.len();
358 let mut selectors: Vec<RowSelector> = Vec::new();
359
360 for &keep in row_group_filter {
361 let selector = if keep {
362 RowSelector::select(rows_per_group)
363 } else {
364 RowSelector::skip(rows_per_group)
365 };
366
367 match selectors.last_mut() {
369 Some(last) if last.skip == selector.skip => {
370 last.row_count = last.row_count.checked_add(rows_per_group).unwrap();
371 }
372 _ => selectors.push(selector),
373 }
374 }
375
376 let covered_rows = num_row_groups * rows_per_group;
378 if covered_rows < total_rows {
379 let remaining = total_rows - covered_rows;
380 match selectors.last_mut() {
382 Some(last) if last.skip => {
383 last.row_count = last.row_count.checked_add(remaining).unwrap();
384 }
385 _ => selectors.push(RowSelector::skip(remaining)),
386 }
387 }
388
389 Self { selectors }
390 }
391
392 pub fn and_then(&self, other: &Self) -> Self {
402 let mut selectors = vec![];
403 let mut first = self.selectors.iter().cloned().peekable();
404 let mut second = other.selectors.iter().cloned().peekable();
405
406 let mut to_skip = 0;
407 while let Some(b) = second.peek_mut() {
408 let a = first
409 .peek_mut()
410 .expect("selection exceeds the number of selected rows");
411
412 if b.row_count == 0 {
413 second.next().unwrap();
414 continue;
415 }
416
417 if a.row_count == 0 {
418 first.next().unwrap();
419 continue;
420 }
421
422 if a.skip {
423 to_skip += a.row_count;
425 first.next().unwrap();
426 continue;
427 }
428
429 let skip = b.skip;
430 let to_process = a.row_count.min(b.row_count);
431
432 a.row_count -= to_process;
433 b.row_count -= to_process;
434
435 match skip {
436 true => to_skip += to_process,
437 false => {
438 if to_skip != 0 {
439 selectors.push(RowSelector::skip(to_skip));
440 to_skip = 0;
441 }
442 selectors.push(RowSelector::select(to_process));
443 }
444 }
445 }
446
447 for v in first {
449 if v.row_count != 0 {
450 assert!(
451 v.skip,
452 "selection contains less than the number of selected rows"
453 );
454 to_skip += v.row_count;
455 }
456 }
457
458 if to_skip != 0 {
459 selectors.push(RowSelector::skip(to_skip));
460 }
461
462 Self { selectors }
463 }
464}
465
466impl From<Vec<RowSelector>> for RowSelection {
467 fn from(selectors: Vec<RowSelector>) -> Self {
468 let mut result: Vec<RowSelector> = Vec::new();
469 for selector in selectors {
470 if selector.row_count == 0 {
471 continue;
472 }
473 match result.last_mut() {
474 Some(last) if last.skip == selector.skip => {
475 last.row_count += selector.row_count;
476 }
477 _ => result.push(selector),
478 }
479 }
480 Self { selectors: result }
481 }
482}
483
484impl From<RowSelection> for Vec<RowSelector> {
485 fn from(selection: RowSelection) -> Self {
486 selection.selectors
487 }
488}
489
490impl FromIterator<RowSelector> for RowSelection {
491 fn from_iter<T: IntoIterator<Item = RowSelector>>(iter: T) -> Self {
492 iter.into_iter().collect::<Vec<_>>().into()
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_row_selector_select() {
        let RowSelector { row_count, skip } = RowSelector::select(100);
        assert_eq!(row_count, 100);
        assert!(!skip);
    }

    #[test]
    fn test_row_selector_skip() {
        let RowSelector { row_count, skip } = RowSelector::skip(50);
        assert_eq!(row_count, 50);
        assert!(skip);
    }

    #[test]
    fn test_row_selection_from_consecutive_ranges() {
        let ranges = [5..10, 15..20];
        let selection = RowSelection::from_consecutive_ranges(ranges.iter().cloned(), 25);

        let expected = [
            RowSelector::skip(5),
            RowSelector::select(5),
            RowSelector::skip(5),
            RowSelector::select(5),
            RowSelector::skip(5),
        ];
        assert_eq!(selection.selectors(), &expected[..]);

        // (total, selected, skipped)
        let counts = (
            selection.row_count(),
            selection.selected_row_count(),
            selection.skipped_row_count(),
        );
        assert_eq!(counts, (25, 10, 15));
    }

    #[test]
    fn test_row_selection_consolidation() {
        let selection = RowSelection::from(vec![
            RowSelector::skip(5),
            RowSelector::skip(5),
            RowSelector::select(10),
            RowSelector::select(5),
        ]);

        let expected = [RowSelector::skip(10), RowSelector::select(15)];
        assert_eq!(selection.selectors(), &expected[..]);
    }

    #[test]
    fn test_row_selection_select_all() {
        let selection = RowSelection::select_all(100);
        let counts = (
            selection.row_count(),
            selection.selected_row_count(),
            selection.skipped_row_count(),
        );
        assert_eq!(counts, (100, 100, 0));
        assert!(selection.selects_any());
    }

    #[test]
    fn test_row_selection_skip_all() {
        let selection = RowSelection::skip_all(100);
        let counts = (
            selection.row_count(),
            selection.selected_row_count(),
            selection.skipped_row_count(),
        );
        assert_eq!(counts, (100, 0, 100));
        assert!(!selection.selects_any());
    }

    #[test]
    fn test_row_selection_split_off() {
        let mut rest =
            RowSelection::from_consecutive_ranges(vec![10..30, 40..60].into_iter(), 100);
        let head = rest.split_off(35);

        // The first 35 rows contain all of 10..30.
        assert_eq!(head.row_count(), 35);
        assert_eq!(head.selected_row_count(), 20);

        // The remaining 65 rows contain all of 40..60.
        assert_eq!(rest.row_count(), 65);
        assert_eq!(rest.selected_row_count(), 20);
    }

    #[test]
    fn test_row_selection_and_then() {
        // Rows 5..15 of 20 are selected...
        let outer = RowSelection::from_consecutive_ranges(std::iter::once(5..15), 20);
        // ...and then rows 2..7 of those 10 are kept.
        let inner = RowSelection::from_consecutive_ranges(std::iter::once(2..7), 10);

        let combined = outer.and_then(&inner);

        assert_eq!(combined.row_count(), 20);
        assert_eq!(combined.selected_row_count(), 5);

        let expected = [
            RowSelector::skip(7),
            RowSelector::select(5),
            RowSelector::skip(8),
        ];
        assert_eq!(combined.selectors(), &expected[..]);
    }

    #[test]
    fn test_row_selection_from_filters() {
        use arrow::array::BooleanArray;

        let filters = [BooleanArray::from(vec![false, false, true, true, false])];
        let selection = RowSelection::from_filters(&filters);

        let expected = [
            RowSelector::skip(2),
            RowSelector::select(2),
            RowSelector::skip(1),
        ];
        assert_eq!(selection.selectors(), &expected[..]);
    }

    #[test]
    fn test_row_selection_empty() {
        let selection = RowSelection::new();
        assert_eq!(
            (selection.row_count(), selection.selected_row_count()),
            (0, 0)
        );
        assert!(!selection.selects_any());
    }

    #[test]
    #[should_panic(expected = "ranges must be provided in order")]
    fn test_row_selection_out_of_order() {
        let _ = RowSelection::from_consecutive_ranges(vec![10..20, 5..15].into_iter(), 25);
    }

    #[test]
    fn test_row_selection_from_row_group_filter() {
        let selection =
            RowSelection::from_row_group_filter(&[false, true, false], 10000, 30000);

        let expected = [
            RowSelector::skip(10000),
            RowSelector::select(10000),
            RowSelector::skip(10000),
        ];
        assert_eq!(selection.selectors(), &expected[..]);

        let counts = (
            selection.row_count(),
            selection.selected_row_count(),
            selection.skipped_row_count(),
        );
        assert_eq!(counts, (30000, 10000, 20000));
    }

    #[test]
    fn test_row_selection_from_row_group_filter_all_keep() {
        let selection = RowSelection::from_row_group_filter(&[true, true, true], 10000, 30000);

        assert_eq!(selection.selectors(), &[RowSelector::select(30000)][..]);
        assert_eq!(selection.selected_row_count(), 30000);
    }

    #[test]
    fn test_row_selection_from_row_group_filter_all_skip() {
        let selection =
            RowSelection::from_row_group_filter(&[false, false, false], 10000, 30000);

        assert_eq!(selection.selectors(), &[RowSelector::skip(30000)][..]);
        assert_eq!(selection.selected_row_count(), 0);
    }

    #[test]
    fn test_row_selection_from_row_group_filter_merge() {
        // Adjacent groups with the same decision collapse into a single run.
        let groups = [false, false, true, true, false];
        let selection = RowSelection::from_row_group_filter(&groups, 10000, 50000);

        let expected = [
            RowSelector::skip(20000),
            RowSelector::select(20000),
            RowSelector::skip(10000),
        ];
        assert_eq!(selection.selectors(), &expected[..]);
        assert_eq!(selection.row_count(), 50000);
    }

    #[test]
    fn test_row_selection_from_row_group_filter_remaining_rows() {
        // 5000 rows lie past the described groups and join the trailing skip.
        let selection = RowSelection::from_row_group_filter(&[true, false], 10000, 25000);

        let expected = [RowSelector::select(10000), RowSelector::skip(15000)];
        assert_eq!(selection.selectors(), &expected[..]);
        assert_eq!(selection.row_count(), 25000);
    }

    #[test]
    fn test_row_selection_from_row_group_filter_empty() {
        let groups: [bool; 0] = [];
        let selection = RowSelection::from_row_group_filter(&groups, 10000, 50000);

        assert_eq!(selection.selectors(), &[RowSelector::skip(50000)][..]);
    }
}