1use arrow::array::{Array, BooleanArray};
24use std::cmp::Ordering;
25use std::ops::Range;
26
/// A run of consecutive rows that is either selected or skipped as a whole.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct RowSelector {
    /// Number of consecutive rows covered by this run.
    pub row_count: usize,

    /// `true` when the rows in this run are skipped, `false` when they are
    /// selected.
    pub skip: bool,
}
39
40impl RowSelector {
41 pub fn select(row_count: usize) -> Self {
43 Self {
44 row_count,
45 skip: false,
46 }
47 }
48
49 pub fn skip(row_count: usize) -> Self {
51 Self {
52 row_count,
53 skip: true,
54 }
55 }
56}
57
58#[derive(Debug, Clone, Default, Eq, PartialEq)]
90pub struct RowSelection {
91 selectors: Vec<RowSelector>,
92}
93
94impl RowSelection {
95 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn from_filters(filters: &[BooleanArray]) -> Self {
106 let mut next_offset = 0;
107 let total_rows = filters.iter().map(|x| x.len()).sum();
108
109 let iter = filters.iter().flat_map(|filter| {
110 let offset = next_offset;
111 next_offset += filter.len();
112 assert_eq!(
113 filter.null_count(),
114 0,
115 "filter arrays must not contain nulls"
116 );
117
118 let mut ranges = vec![];
120 let mut start = None;
121 for (idx, value) in filter.iter().enumerate() {
122 match (value, start) {
123 (Some(true), None) => start = Some(idx),
124 (Some(false), Some(s)) | (None, Some(s)) => {
125 ranges.push(s + offset..idx + offset);
126 start = None;
127 }
128 _ => {}
129 }
130 }
131 if let Some(s) = start {
132 ranges.push(s + offset..filter.len() + offset);
133 }
134 ranges
135 });
136
137 Self::from_consecutive_ranges(iter, total_rows)
138 }
139
140 pub fn from_consecutive_ranges<I: Iterator<Item = Range<usize>>>(
159 ranges: I,
160 total_rows: usize,
161 ) -> Self {
162 let mut selectors: Vec<RowSelector> = Vec::with_capacity(ranges.size_hint().0);
163 let mut last_end = 0;
164
165 for range in ranges {
166 let len = range.end - range.start;
167 if len == 0 {
168 continue;
169 }
170
171 match range.start.cmp(&last_end) {
172 Ordering::Equal => {
173 match selectors.last_mut() {
175 Some(last) if !last.skip => {
176 last.row_count = last.row_count.checked_add(len).unwrap()
177 }
178 _ => selectors.push(RowSelector::select(len)),
179 }
180 }
181 Ordering::Greater => {
182 selectors.push(RowSelector::skip(range.start - last_end));
184 selectors.push(RowSelector::select(len));
185 }
186 Ordering::Less => {
187 panic!("ranges must be provided in order and must not overlap")
188 }
189 }
190 last_end = range.end;
191 }
192
193 if last_end < total_rows {
195 selectors.push(RowSelector::skip(total_rows - last_end));
196 }
197
198 Self { selectors }
199 }
200
201 pub fn select_all(row_count: usize) -> Self {
203 if row_count == 0 {
204 return Self::default();
205 }
206 Self {
207 selectors: vec![RowSelector::select(row_count)],
208 }
209 }
210
211 pub fn skip_all(row_count: usize) -> Self {
213 if row_count == 0 {
214 return Self::default();
215 }
216 Self {
217 selectors: vec![RowSelector::skip(row_count)],
218 }
219 }
220
221 pub fn row_count(&self) -> usize {
223 self.selectors.iter().map(|s| s.row_count).sum()
224 }
225
226 pub fn selected_row_count(&self) -> usize {
228 self.selectors
229 .iter()
230 .filter(|s| !s.skip)
231 .map(|s| s.row_count)
232 .sum()
233 }
234
235 pub fn skipped_row_count(&self) -> usize {
237 self.selectors
238 .iter()
239 .filter(|s| s.skip)
240 .map(|s| s.row_count)
241 .sum()
242 }
243
244 pub fn selects_any(&self) -> bool {
246 self.selectors.iter().any(|s| !s.skip)
247 }
248
249 pub fn iter(&self) -> impl Iterator<Item = &RowSelector> {
251 self.selectors.iter()
252 }
253
254 pub fn selectors(&self) -> &[RowSelector] {
256 &self.selectors
257 }
258
259 pub fn split_off(&mut self, row_count: usize) -> Self {
279 let mut total_count = 0;
280
281 let find = self.selectors.iter().position(|selector| {
283 total_count += selector.row_count;
284 total_count > row_count
285 });
286
287 let split_idx = match find {
288 Some(idx) => idx,
289 None => {
290 let selectors = std::mem::take(&mut self.selectors);
292 return Self { selectors };
293 }
294 };
295
296 let mut remaining = self.selectors.split_off(split_idx);
297
298 let next = remaining.first_mut().unwrap();
300 let overflow = total_count - row_count;
301
302 if next.row_count != overflow {
303 self.selectors.push(RowSelector {
304 row_count: next.row_count - overflow,
305 skip: next.skip,
306 });
307 }
308 next.row_count = overflow;
309
310 std::mem::swap(&mut remaining, &mut self.selectors);
311 Self {
312 selectors: remaining,
313 }
314 }
315
316 pub fn and_then(&self, other: &Self) -> Self {
326 let mut selectors = vec![];
327 let mut first = self.selectors.iter().cloned().peekable();
328 let mut second = other.selectors.iter().cloned().peekable();
329
330 let mut to_skip = 0;
331 while let Some(b) = second.peek_mut() {
332 let a = first
333 .peek_mut()
334 .expect("selection exceeds the number of selected rows");
335
336 if b.row_count == 0 {
337 second.next().unwrap();
338 continue;
339 }
340
341 if a.row_count == 0 {
342 first.next().unwrap();
343 continue;
344 }
345
346 if a.skip {
347 to_skip += a.row_count;
349 first.next().unwrap();
350 continue;
351 }
352
353 let skip = b.skip;
354 let to_process = a.row_count.min(b.row_count);
355
356 a.row_count -= to_process;
357 b.row_count -= to_process;
358
359 match skip {
360 true => to_skip += to_process,
361 false => {
362 if to_skip != 0 {
363 selectors.push(RowSelector::skip(to_skip));
364 to_skip = 0;
365 }
366 selectors.push(RowSelector::select(to_process));
367 }
368 }
369 }
370
371 for v in first {
373 if v.row_count != 0 {
374 assert!(
375 v.skip,
376 "selection contains less than the number of selected rows"
377 );
378 to_skip += v.row_count;
379 }
380 }
381
382 if to_skip != 0 {
383 selectors.push(RowSelector::skip(to_skip));
384 }
385
386 Self { selectors }
387 }
388}
389
390impl From<Vec<RowSelector>> for RowSelection {
391 fn from(selectors: Vec<RowSelector>) -> Self {
392 let mut result: Vec<RowSelector> = Vec::new();
393 for selector in selectors {
394 if selector.row_count == 0 {
395 continue;
396 }
397 match result.last_mut() {
398 Some(last) if last.skip == selector.skip => {
399 last.row_count += selector.row_count;
400 }
401 _ => result.push(selector),
402 }
403 }
404 Self { selectors: result }
405 }
406}
407
408impl From<RowSelection> for Vec<RowSelector> {
409 fn from(selection: RowSelection) -> Self {
410 selection.selectors
411 }
412}
413
414impl FromIterator<RowSelector> for RowSelection {
415 fn from_iter<T: IntoIterator<Item = RowSelector>>(iter: T) -> Self {
416 iter.into_iter().collect::<Vec<_>>().into()
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_row_selector_select() {
        let selector = RowSelector::select(100);
        assert!(!selector.skip);
        assert_eq!(selector.row_count, 100);
    }

    #[test]
    fn test_row_selector_skip() {
        let selector = RowSelector::skip(50);
        assert!(selector.skip);
        assert_eq!(selector.row_count, 50);
    }

    #[test]
    fn test_row_selection_from_consecutive_ranges() {
        // Rows 5..10 and 15..20 out of 25 are selected; every gap is a skip.
        let ranges = vec![5..10, 15..20];
        let selection = RowSelection::from_consecutive_ranges(ranges.into_iter(), 25);

        assert_eq!(
            selection.selectors,
            vec![
                RowSelector::skip(5),
                RowSelector::select(5),
                RowSelector::skip(5),
                RowSelector::select(5),
                RowSelector::skip(5),
            ]
        );
        assert_eq!(selection.row_count(), 25);
        assert_eq!(selection.selected_row_count(), 10);
        assert_eq!(selection.skipped_row_count(), 15);
    }

    #[test]
    fn test_row_selection_consolidation() {
        // Neighbouring selectors of the same kind collapse into one.
        let selection: RowSelection = vec![
            RowSelector::skip(5),
            RowSelector::skip(5),
            RowSelector::select(10),
            RowSelector::select(5),
        ]
        .into();

        assert_eq!(
            selection.selectors,
            vec![RowSelector::skip(10), RowSelector::select(15)]
        );
    }

    #[test]
    fn test_row_selection_select_all() {
        let selection = RowSelection::select_all(100);
        assert!(selection.selects_any());
        assert_eq!(selection.row_count(), 100);
        assert_eq!(selection.selected_row_count(), 100);
        assert_eq!(selection.skipped_row_count(), 0);
    }

    #[test]
    fn test_row_selection_skip_all() {
        let selection = RowSelection::skip_all(100);
        assert!(!selection.selects_any());
        assert_eq!(selection.row_count(), 100);
        assert_eq!(selection.selected_row_count(), 0);
        assert_eq!(selection.skipped_row_count(), 100);
    }

    #[test]
    fn test_row_selection_split_off() {
        // Layout: skip 10, select 20, skip 10, select 20, skip 40.
        let mut selection =
            RowSelection::from_consecutive_ranges(vec![10..30, 40..60].into_iter(), 100);

        // The split at row 35 falls inside the second skip run.
        let first = selection.split_off(35);

        assert_eq!(first.row_count(), 35);
        assert_eq!(selection.row_count(), 65);

        // The head keeps rows 10..30.
        assert_eq!(first.selected_row_count(), 20);

        // The tail keeps rows 40..60.
        assert_eq!(selection.selected_row_count(), 20);
    }

    #[test]
    fn test_row_selection_and_then() {
        // Outer: rows 5..15 of 20 are selected.
        let first = RowSelection::from_consecutive_ranges(std::iter::once(5..15), 20);

        // Inner: rows 2..7 of those 10 selected rows are kept.
        let second = RowSelection::from_consecutive_ranges(std::iter::once(2..7), 10);

        let result = first.and_then(&second);

        // The result is expressed over the outer selection's 20 rows.
        assert_eq!(result.row_count(), 20);
        assert_eq!(result.selected_row_count(), 5);

        assert_eq!(
            result.selectors,
            vec![
                RowSelector::skip(7),
                RowSelector::select(5),
                RowSelector::skip(8),
            ]
        );
    }

    #[test]
    fn test_row_selection_from_filters() {
        use arrow::array::BooleanArray;

        let filter = BooleanArray::from(vec![false, false, true, true, false]);

        let selection = RowSelection::from_filters(&[filter]);

        assert_eq!(
            selection.selectors,
            vec![
                RowSelector::skip(2),
                RowSelector::select(2),
                RowSelector::skip(1),
            ]
        );
    }

    #[test]
    fn test_row_selection_empty() {
        let selection = RowSelection::new();
        assert!(!selection.selects_any());
        assert_eq!(selection.row_count(), 0);
        assert_eq!(selection.selected_row_count(), 0);
    }

    #[test]
    #[should_panic(expected = "ranges must be provided in order")]
    fn test_row_selection_out_of_order() {
        // The second range starts before the first one ends.
        let ranges = vec![10..20, 5..15];
        RowSelection::from_consecutive_ranges(ranges.into_iter(), 25);
    }
}