use arrow::array::{Array, BooleanArray};
use std::cmp::Ordering;
use std::ops::Range;
/// A run of consecutive rows that are either all selected or all skipped.
///
/// Build one with [`RowSelector::select`] or [`RowSelector::skip`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct RowSelector {
/// Number of consecutive rows covered by this run.
pub row_count: usize,
/// `true` when the rows of this run are skipped, `false` when they are selected.
pub skip: bool,
}
impl RowSelector {
    /// Returns a run that marks the next `row_count` rows as selected
    /// (`skip` is `false`).
    pub fn select(row_count: usize) -> Self {
        let skip = false;
        Self { row_count, skip }
    }

    /// Returns a run that marks the next `row_count` rows as skipped
    /// (`skip` is `true`).
    pub fn skip(row_count: usize) -> Self {
        let skip = true;
        Self { row_count, skip }
    }
}
/// An ordered sequence of [`RowSelector`] runs describing which rows to select
/// and which to skip.
///
/// The default value holds no selectors and therefore covers zero rows.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct RowSelection {
// Runs in row order: the first selector applies to the first rows, the next
// selector to the rows immediately after it, and so on.
selectors: Vec<RowSelector>,
}
impl RowSelection {
    /// Creates an empty selection that covers zero rows.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a selection from boolean filter arrays laid end to end.
    ///
    /// Every `true` value selects the corresponding row and every `false`
    /// value skips it. Row positions continue across array boundaries, so a
    /// run of `true` values that spans two arrays becomes a single selector.
    ///
    /// # Panics
    ///
    /// Panics if any filter array contains nulls.
    pub fn from_filters(filters: &[BooleanArray]) -> Self {
        let total_rows: usize = filters.iter().map(|filter| filter.len()).sum();
        // Position of the filter currently being scanned within the
        // concatenation of all filters.
        let mut next_offset = 0;
        let ranges = filters.iter().flat_map(|filter| {
            let offset = next_offset;
            next_offset += filter.len();
            assert_eq!(
                filter.null_count(),
                0,
                "filter arrays must not contain nulls"
            );
            let mut ranges = vec![];
            // Index (within this filter) where the current run of `true` began.
            let mut run_start = None;
            for (idx, value) in filter.iter().enumerate() {
                match (value == Some(true), run_start) {
                    (true, None) => run_start = Some(idx),
                    (false, Some(start)) => {
                        ranges.push(start + offset..idx + offset);
                        run_start = None;
                    }
                    _ => {}
                }
            }
            // A run that reaches the end of the filter is closed here; if the
            // next filter starts with `true` the two ranges touch and are
            // merged by `from_consecutive_ranges`.
            if let Some(start) = run_start {
                ranges.push(start + offset..filter.len() + offset);
            }
            ranges
        });
        Self::from_consecutive_ranges(ranges, total_rows)
    }

    /// Builds a selection from sorted, non-overlapping ranges of selected rows.
    ///
    /// Gaps between ranges become skip selectors, ranges that touch are merged
    /// into one select selector, and empty ranges are ignored. If the last
    /// range ends before `total_rows`, a trailing skip selector covers the
    /// rest. No trailing selector is added when the ranges already reach or
    /// pass `total_rows`.
    ///
    /// # Panics
    ///
    /// Panics if a range starts before the end of the previous one, if a
    /// range's end precedes its start, or if a merged row count overflows
    /// `usize`.
    pub fn from_consecutive_ranges<I: Iterator<Item = Range<usize>>>(
        ranges: I,
        total_rows: usize,
    ) -> Self {
        let mut selectors: Vec<RowSelector> = Vec::with_capacity(ranges.size_hint().0);
        let mut last_end = 0;
        for range in ranges {
            // A plain subtraction would wrap in release builds for a reversed
            // range and produce a huge bogus selector; fail loudly instead.
            let len = range
                .end
                .checked_sub(range.start)
                .expect("range end must not precede range start");
            if len == 0 {
                continue;
            }
            match range.start.cmp(&last_end) {
                Ordering::Equal => match selectors.last_mut() {
                    // Touches the previous selected range: extend it.
                    Some(last) if !last.skip => {
                        last.row_count = last
                            .row_count
                            .checked_add(len)
                            .expect("row count overflows usize");
                    }
                    _ => selectors.push(RowSelector::select(len)),
                },
                Ordering::Greater => {
                    selectors.push(RowSelector::skip(range.start - last_end));
                    selectors.push(RowSelector::select(len));
                }
                Ordering::Less => {
                    panic!("ranges must be provided in order and must not overlap")
                }
            }
            last_end = range.end;
        }
        if last_end < total_rows {
            selectors.push(RowSelector::skip(total_rows - last_end));
        }
        Self { selectors }
    }

    /// Returns a selection that selects `row_count` rows, or an empty
    /// selection when `row_count` is zero.
    pub fn select_all(row_count: usize) -> Self {
        if row_count == 0 {
            return Self::default();
        }
        Self {
            selectors: vec![RowSelector::select(row_count)],
        }
    }

    /// Returns a selection that skips `row_count` rows, or an empty selection
    /// when `row_count` is zero.
    pub fn skip_all(row_count: usize) -> Self {
        if row_count == 0 {
            return Self::default();
        }
        Self {
            selectors: vec![RowSelector::skip(row_count)],
        }
    }

    /// Total number of rows covered, selected and skipped alike.
    pub fn row_count(&self) -> usize {
        self.selectors.iter().map(|s| s.row_count).sum()
    }

    /// Number of rows covered by select selectors.
    pub fn selected_row_count(&self) -> usize {
        self.selectors
            .iter()
            .filter(|s| !s.skip)
            .map(|s| s.row_count)
            .sum()
    }

    /// Number of rows covered by skip selectors.
    pub fn skipped_row_count(&self) -> usize {
        self.selectors
            .iter()
            .filter(|s| s.skip)
            .map(|s| s.row_count)
            .sum()
    }

    /// Returns `true` if the selection contains at least one select selector.
    ///
    /// The row count of that selector is not inspected.
    pub fn selects_any(&self) -> bool {
        self.selectors.iter().any(|s| !s.skip)
    }

    /// Iterates over the selectors in row order.
    pub fn iter(&self) -> impl Iterator<Item = &RowSelector> {
        self.selectors.iter()
    }

    /// Returns the selectors in row order as a slice.
    pub fn selectors(&self) -> &[RowSelector] {
        &self.selectors
    }

    /// Splits off and returns the first `row_count` rows; `self` keeps the rest.
    ///
    /// A selector that straddles the split point is divided between the two
    /// halves. If the selection covers `row_count` rows or fewer, everything is
    /// returned and `self` becomes empty.
    pub fn split_off(&mut self, row_count: usize) -> Self {
        let mut total_count = 0;
        // First selector whose cumulative row count passes the split point.
        let find = self.selectors.iter().position(|selector| {
            total_count += selector.row_count;
            total_count > row_count
        });
        let split_idx = match find {
            Some(idx) => idx,
            None => {
                let selectors = std::mem::take(&mut self.selectors);
                return Self { selectors };
            }
        };
        let mut remaining = self.selectors.split_off(split_idx);
        // Always present: `split_idx` indexes an existing selector.
        let next = remaining.first_mut().unwrap();
        // Rows of the straddling selector that lie beyond the split point.
        let overflow = total_count - row_count;
        if next.row_count != overflow {
            // The leading part of the straddling selector stays with the head.
            self.selectors.push(RowSelector {
                row_count: next.row_count - overflow,
                skip: next.skip,
            });
        }
        next.row_count = overflow;
        // `self.selectors` currently holds the head and `remaining` the tail;
        // swap so that `self` keeps the tail and the head is returned.
        std::mem::swap(&mut remaining, &mut self.selectors);
        Self {
            selectors: remaining,
        }
    }

    /// Builds a selection from one keep/skip flag per row group.
    ///
    /// Each entry of `row_group_filter` covers `rows_per_group` rows, in order:
    /// `true` selects the group and `false` skips it. Adjacent groups with the
    /// same flag are merged into one selector. The result always covers exactly
    /// `total_rows` rows:
    ///
    /// * a final group that would run past `total_rows` is truncated, and
    ///   groups lying entirely past `total_rows` are ignored;
    /// * rows not covered by any group are skipped, which for an empty filter
    ///   means every row;
    /// * no zero-length selectors are produced.
    pub fn from_row_group_filter(
        row_group_filter: &[bool],
        rows_per_group: usize,
        total_rows: usize,
    ) -> Self {
        let mut selectors: Vec<RowSelector> = Vec::new();
        // Rows not yet assigned to a selector. Counting down from `total_rows`
        // avoids computing `groups * rows_per_group`, which could overflow.
        let mut remaining = total_rows;
        for &keep in row_group_filter {
            let rows = rows_per_group.min(remaining);
            if rows == 0 {
                // Either all rows are assigned or groups are empty; later
                // groups cannot contribute anything.
                break;
            }
            remaining -= rows;
            Self::push_merged(
                &mut selectors,
                RowSelector {
                    row_count: rows,
                    skip: !keep,
                },
            );
        }
        if remaining != 0 {
            Self::push_merged(&mut selectors, RowSelector::skip(remaining));
        }
        Self { selectors }
    }

    /// Applies `other` to the rows selected by `self`.
    ///
    /// `other` is expressed relative to the selected rows of `self`: its first
    /// row is the first row `self` selects, and so on. The result is expressed
    /// in the same row space as `self`; a row is selected only if both
    /// selections select it.
    ///
    /// # Panics
    ///
    /// Panics if `other` covers more rows than `self` selects, or fewer.
    /// Zero-length selectors in `other` never count towards either condition.
    pub fn and_then(&self, other: &Self) -> Self {
        let mut selectors = vec![];
        let mut first = self.selectors.iter().cloned().peekable();
        let mut second = other.selectors.iter().cloned().peekable();
        // Skipped rows accumulated since the last emitted selector.
        let mut to_skip = 0;
        while let Some(b) = second.peek_mut() {
            // Discard exhausted or empty selectors of `other` before looking at
            // `self`, so they cannot trigger the panic below on their own.
            if b.row_count == 0 {
                second.next();
                continue;
            }
            let a = first
                .peek_mut()
                .expect("selection exceeds the number of selected rows");
            if a.row_count == 0 {
                first.next();
                continue;
            }
            if a.skip {
                // Rows skipped by `self` are invisible to `other`.
                to_skip += a.row_count;
                first.next();
                continue;
            }
            let skip = b.skip;
            let to_process = a.row_count.min(b.row_count);
            a.row_count -= to_process;
            b.row_count -= to_process;
            if skip {
                to_skip += to_process;
            } else {
                if to_skip != 0 {
                    selectors.push(RowSelector::skip(to_skip));
                    to_skip = 0;
                }
                selectors.push(RowSelector::select(to_process));
            }
        }
        // Whatever is left in `self` must be skipped rows.
        for v in first {
            if v.row_count != 0 {
                assert!(
                    v.skip,
                    "selection contains less than the number of selected rows"
                );
                to_skip += v.row_count;
            }
        }
        if to_skip != 0 {
            selectors.push(RowSelector::skip(to_skip));
        }
        Self { selectors }
    }

    /// Appends `selector`, folding it into the last selector when both are of
    /// the same kind.
    fn push_merged(selectors: &mut Vec<RowSelector>, selector: RowSelector) {
        match selectors.last_mut() {
            Some(last) if last.skip == selector.skip => {
                last.row_count = last
                    .row_count
                    .checked_add(selector.row_count)
                    .expect("row count overflows usize");
            }
            _ => selectors.push(selector),
        }
    }
}
impl From<Vec<RowSelector>> for RowSelection {
    /// Builds a selection from `selectors`, dropping zero-length selectors and
    /// merging neighbouring selectors of the same kind.
    ///
    /// # Panics
    ///
    /// Panics if merging two selectors overflows `usize`.
    fn from(selectors: Vec<RowSelector>) -> Self {
        let mut merged: Vec<RowSelector> = Vec::new();
        for selector in selectors.into_iter().filter(|s| s.row_count != 0) {
            match merged.last_mut() {
                Some(last) if last.skip == selector.skip => {
                    // Checked so that an overflow panics in every build
                    // profile instead of wrapping silently in release builds.
                    last.row_count = last
                        .row_count
                        .checked_add(selector.row_count)
                        .expect("row count overflows usize");
                }
                _ => merged.push(selector),
            }
        }
        Self { selectors: merged }
    }
}
impl From<RowSelection> for Vec<RowSelector> {
fn from(selection: RowSelection) -> Self {
selection.selectors
}
}
impl FromIterator<RowSelector> for RowSelection {
    /// Collects the selectors into a vector and converts that vector through
    /// the `From<Vec<RowSelector>>` implementation.
    fn from_iter<T: IntoIterator<Item = RowSelector>>(iter: T) -> Self {
        let buffered: Vec<RowSelector> = iter.into_iter().collect();
        Self::from(buffered)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `selection` consists of exactly `expected`, in order.
    fn assert_selectors(selection: &RowSelection, expected: &[RowSelector]) {
        assert_eq!(selection.selectors(), expected);
    }

    #[test]
    fn test_row_selector_select() {
        let RowSelector { row_count, skip } = RowSelector::select(100);
        assert_eq!(row_count, 100);
        assert!(!skip);
    }

    #[test]
    fn test_row_selector_skip() {
        let RowSelector { row_count, skip } = RowSelector::skip(50);
        assert_eq!(row_count, 50);
        assert!(skip);
    }

    #[test]
    fn test_row_selection_from_consecutive_ranges() {
        let ranges = vec![5..10, 15..20];
        let selection = RowSelection::from_consecutive_ranges(ranges.into_iter(), 25);
        assert_selectors(
            &selection,
            &[
                RowSelector::skip(5),
                RowSelector::select(5),
                RowSelector::skip(5),
                RowSelector::select(5),
                RowSelector::skip(5),
            ],
        );
        assert_eq!(selection.row_count(), 25);
        assert_eq!(selection.selected_row_count(), 10);
        assert_eq!(selection.skipped_row_count(), 15);
    }

    #[test]
    fn test_row_selection_consolidation() {
        // Neighbouring selectors of the same kind must collapse into one.
        let selection = RowSelection::from(vec![
            RowSelector::skip(5),
            RowSelector::skip(5),
            RowSelector::select(10),
            RowSelector::select(5),
        ]);
        assert_selectors(
            &selection,
            &[RowSelector::skip(10), RowSelector::select(15)],
        );
    }

    #[test]
    fn test_row_selection_select_all() {
        let selection = RowSelection::select_all(100);
        assert_eq!(selection.row_count(), 100);
        assert_eq!(selection.selected_row_count(), 100);
        assert_eq!(selection.skipped_row_count(), 0);
        assert!(selection.selects_any());
    }

    #[test]
    fn test_row_selection_skip_all() {
        let selection = RowSelection::skip_all(100);
        assert_eq!(selection.row_count(), 100);
        assert_eq!(selection.selected_row_count(), 0);
        assert_eq!(selection.skipped_row_count(), 100);
        assert!(!selection.selects_any());
    }

    #[test]
    fn test_row_selection_split_off() {
        let ranges = vec![10..30, 40..60];
        let mut tail = RowSelection::from_consecutive_ranges(ranges.into_iter(), 100);
        let head = tail.split_off(35);
        assert_eq!(head.row_count(), 35);
        assert_eq!(tail.row_count(), 65);
        assert_eq!(head.selected_row_count(), 20);
        assert_eq!(tail.selected_row_count(), 20);
    }

    #[test]
    fn test_row_selection_and_then() {
        let outer = RowSelection::from_consecutive_ranges(std::iter::once(5..15), 20);
        let inner = RowSelection::from_consecutive_ranges(std::iter::once(2..7), 10);
        let combined = outer.and_then(&inner);
        assert_eq!(combined.row_count(), 20);
        assert_eq!(combined.selected_row_count(), 5);
        assert_selectors(
            &combined,
            &[
                RowSelector::skip(7),
                RowSelector::select(5),
                RowSelector::skip(8),
            ],
        );
    }

    #[test]
    fn test_row_selection_from_filters() {
        use arrow::array::BooleanArray;
        let filters = [BooleanArray::from(vec![false, false, true, true, false])];
        let selection = RowSelection::from_filters(&filters);
        assert_selectors(
            &selection,
            &[
                RowSelector::skip(2),
                RowSelector::select(2),
                RowSelector::skip(1),
            ],
        );
    }

    #[test]
    fn test_row_selection_empty() {
        let selection = RowSelection::new();
        assert_eq!(selection.row_count(), 0);
        assert_eq!(selection.selected_row_count(), 0);
        assert!(!selection.selects_any());
    }

    #[test]
    #[should_panic(expected = "ranges must be provided in order")]
    fn test_row_selection_out_of_order() {
        let ranges = vec![10..20, 5..15];
        RowSelection::from_consecutive_ranges(ranges.into_iter(), 25);
    }

    #[test]
    fn test_row_selection_from_row_group_filter() {
        let selection = RowSelection::from_row_group_filter(&[false, true, false], 10000, 30000);
        assert_selectors(
            &selection,
            &[
                RowSelector::skip(10000),
                RowSelector::select(10000),
                RowSelector::skip(10000),
            ],
        );
        assert_eq!(selection.row_count(), 30000);
        assert_eq!(selection.selected_row_count(), 10000);
        assert_eq!(selection.skipped_row_count(), 20000);
    }

    #[test]
    fn test_row_selection_from_row_group_filter_all_keep() {
        let selection = RowSelection::from_row_group_filter(&[true, true, true], 10000, 30000);
        assert_selectors(&selection, &[RowSelector::select(30000)]);
        assert_eq!(selection.selected_row_count(), 30000);
    }

    #[test]
    fn test_row_selection_from_row_group_filter_all_skip() {
        let selection = RowSelection::from_row_group_filter(&[false, false, false], 10000, 30000);
        assert_selectors(&selection, &[RowSelector::skip(30000)]);
        assert_eq!(selection.selected_row_count(), 0);
    }

    #[test]
    fn test_row_selection_from_row_group_filter_merge() {
        let keep = [false, false, true, true, false];
        let selection = RowSelection::from_row_group_filter(&keep, 10000, 50000);
        assert_selectors(
            &selection,
            &[
                RowSelector::skip(20000),
                RowSelector::select(20000),
                RowSelector::skip(10000),
            ],
        );
        assert_eq!(selection.row_count(), 50000);
    }

    #[test]
    fn test_row_selection_from_row_group_filter_remaining_rows() {
        // Two groups cover 20000 rows; the other 5000 join the trailing skip.
        let selection = RowSelection::from_row_group_filter(&[true, false], 10000, 25000);
        assert_selectors(
            &selection,
            &[RowSelector::select(10000), RowSelector::skip(15000)],
        );
        assert_eq!(selection.row_count(), 25000);
    }

    #[test]
    fn test_row_selection_from_row_group_filter_empty() {
        let selection = RowSelection::from_row_group_filter(&[], 10000, 50000);
        assert_selectors(&selection, &[RowSelector::skip(50000)]);
    }
}