lance_core/utils/mask.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::collections::HashSet;
use std::io::Write;
use std::iter;
use std::ops::{Range, RangeBounds};
use std::{collections::BTreeMap, io::Read};

use arrow_array::{Array, BinaryArray, GenericBinaryArray};
use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer};
use byteorder::{ReadBytesExt, WriteBytesExt};
use deepsize::DeepSizeOf;
use roaring::{MultiOps, RoaringBitmap, RoaringTreemap};

use crate::Result;

use super::address::RowAddress;

/// A row id mask to select or deselect particular row ids
///
/// If both the allow_list and the block_list are Some then the only selected
/// row ids are those that are in the allow_list but not in the block_list
/// (the block_list takes precedence)
///
/// If both the allow_list and the block_list are None (the default) then
/// all row ids are selected
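///
/// For example (a minimal sketch of the precedence rules):
///
/// ```rust
/// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
///
/// let mask = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[1, 2, 3]))
///     .also_block(RowIdTreeMap::from_iter(&[2]));
/// assert!(mask.selected(1));
/// // The block list takes precedence over the allow list
/// assert!(!mask.selected(2));
/// // Row ids outside the allow list are not selected
/// assert!(!mask.selected(4));
/// ```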
#[derive(Clone, Debug, Default, DeepSizeOf)]
pub struct RowIdMask {
    /// If Some then only these row ids are selected
    pub allow_list: Option<RowIdTreeMap>,
    /// If Some then these row ids are not selected.
    pub block_list: Option<RowIdTreeMap>,
}

impl RowIdMask {
    /// Create a mask allowing all rows. This is an alias for `Self::default`.
    pub fn all_rows() -> Self {
        Self::default()
    }

    /// Create a mask that doesn't allow anything
    pub fn allow_nothing() -> Self {
        Self {
            allow_list: Some(RowIdTreeMap::new()),
            block_list: None,
        }
    }

    /// Create a mask from an allow list
    pub fn from_allowed(allow_list: RowIdTreeMap) -> Self {
        Self {
            allow_list: Some(allow_list),
            block_list: None,
        }
    }

    /// Create a mask from a block list
    pub fn from_block(block_list: RowIdTreeMap) -> Self {
        Self {
            allow_list: None,
            block_list: Some(block_list),
        }
    }

    /// True if the row_id is selected by the mask, false otherwise
    pub fn selected(&self, row_id: u64) -> bool {
        match (&self.allow_list, &self.block_list) {
            (None, None) => true,
            (Some(allow_list), None) => allow_list.contains(row_id),
            (None, Some(block_list)) => !block_list.contains(row_id),
            (Some(allow_list), Some(block_list)) => {
                allow_list.contains(row_id) && !block_list.contains(row_id)
            }
        }
    }

    /// Return the indices of the input row ids that were valid
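    ///
    /// For example (a small sketch using a block list only):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
    ///
    /// let mask = RowIdMask::from_block(RowIdTreeMap::from_iter(&[10]));
    /// let row_ids = [5_u64, 10, 15];
    /// // Row ids 5 and 15 survive the mask, at indices 0 and 2
    /// assert_eq!(mask.selected_indices(row_ids.iter()), vec![0, 2]);
    /// ```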
    pub fn selected_indices<'a>(&self, row_ids: impl Iterator<Item = &'a u64> + 'a) -> Vec<u64> {
        let enumerated_ids = row_ids.enumerate();
        match (&self.block_list, &self.allow_list) {
            (Some(block_list), Some(allow_list)) => {
                // Only take rows that are both in the allow list and not in the block list
                enumerated_ids
                    .filter(|(_, row_id)| {
                        !block_list.contains(**row_id) && allow_list.contains(**row_id)
                    })
                    .map(|(idx, _)| idx as u64)
                    .collect()
            }
            (Some(block_list), None) => {
                // Take rows that are not in the block list
                enumerated_ids
                    .filter(|(_, row_id)| !block_list.contains(**row_id))
                    .map(|(idx, _)| idx as u64)
                    .collect()
            }
            (None, Some(allow_list)) => {
                // Take rows that are in the allow list
                enumerated_ids
                    .filter(|(_, row_id)| allow_list.contains(**row_id))
                    .map(|(idx, _)| idx as u64)
                    .collect()
            }
            (None, None) => {
                // We should not encounter this case because callers should
                // check is_empty first.
                panic!("selected_indices called but prefilter has nothing to filter with")
            }
        }
    }

    /// Also block the given ids
    pub fn also_block(self, block_list: RowIdTreeMap) -> Self {
        if block_list.is_empty() {
            return self;
        }
        if let Some(existing) = self.block_list {
            Self {
                block_list: Some(existing | block_list),
                allow_list: self.allow_list,
            }
        } else {
            Self {
                block_list: Some(block_list),
                allow_list: self.allow_list,
            }
        }
    }

    /// Also allow the given ids
    pub fn also_allow(self, allow_list: RowIdTreeMap) -> Self {
        if let Some(existing) = self.allow_list {
            Self {
                block_list: self.block_list,
                allow_list: Some(existing | allow_list),
            }
        } else {
            Self {
                block_list: self.block_list,
                // allow_list = None means "all rows allowed" and so allowing
                //              more rows is meaningless
                allow_list: None,
            }
        }
    }

    /// Convert a mask into an Arrow array
    ///
    /// A row id mask is not very Arrow-compatible.  We can't make it a batch with
    /// two columns because the block list and allow list will have different lengths.  Also,
    /// there is no Arrow type for compressed bitmaps.
    ///
    /// However, we need to shove it into some kind of Arrow container to pass it along the
    /// DataFusion stream.  Perhaps, in the future, we can add row id masks as first-class
    /// types in DataFusion, and this can be passed along as a mask / selection vector.
    ///
    /// We serialize this as a variable-length binary array with two items.  The first item
    /// is the block list and the second item is the allow list.
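    ///
    /// For example (a minimal sketch that round-trips a mask through Arrow):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
    ///
    /// let mask = RowIdMask::from_block(RowIdTreeMap::from_iter(&[7]));
    /// let array = mask.into_arrow().unwrap();
    /// let roundtripped = RowIdMask::from_arrow(&array).unwrap();
    /// assert!(!roundtripped.selected(7));
    /// assert!(roundtripped.selected(8));
    /// ```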
    pub fn into_arrow(&self) -> Result<BinaryArray> {
        let block_list_length = self
            .block_list
            .as_ref()
            .map(|bl| bl.serialized_size())
            .unwrap_or(0);
        let allow_list_length = self
            .allow_list
            .as_ref()
            .map(|al| al.serialized_size())
            .unwrap_or(0);
        let lengths = vec![block_list_length, allow_list_length];
        let offsets = OffsetBuffer::from_lengths(lengths);
        let mut value_bytes = vec![0; block_list_length + allow_list_length];
        let mut validity = vec![false, false];
        if let Some(block_list) = &self.block_list {
            validity[0] = true;
            block_list.serialize_into(&mut value_bytes[0..])?;
        }
        if let Some(allow_list) = &self.allow_list {
            validity[1] = true;
            allow_list.serialize_into(&mut value_bytes[block_list_length..])?;
        }
        let values = Buffer::from(value_bytes);
        let nulls = NullBuffer::from(validity);
        Ok(BinaryArray::try_new(offsets, values, Some(nulls))?)
    }

    /// Deserialize a row id mask from Arrow
    pub fn from_arrow(array: &GenericBinaryArray<i32>) -> Result<Self> {
        let block_list = if array.is_null(0) {
            None
        } else {
            Some(RowIdTreeMap::deserialize_from(array.value(0)))
        }
        .transpose()?;

        let allow_list = if array.is_null(1) {
            None
        } else {
            Some(RowIdTreeMap::deserialize_from(array.value(1)))
        }
        .transpose()?;
        Ok(Self {
            block_list,
            allow_list,
        })
    }

    /// Return the maximum number of row ids that could be selected by this mask
    ///
    /// Will be None if there is no allow list
    pub fn max_len(&self) -> Option<u64> {
        if let Some(allow_list) = &self.allow_list {
            // If there is a block list we could theoretically intersect the two
            // but it's not clear if that is worth the effort.  Feel free to add later.
            allow_list.len()
        } else {
            None
        }
    }

    /// Iterate over the row ids that are selected by the mask
    ///
    /// This is only possible if there is an allow list and neither the
    /// allow list nor the block list contains any "full fragment" blocks.
    ///
    /// TODO: We could probably still iterate efficiently even if the block
    /// list contains "full fragment" blocks but that would require some
    /// extra logic.
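    ///
    /// For example (a minimal sketch):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
    ///
    /// let mask = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[1, 2]))
    ///     .also_block(RowIdTreeMap::from_iter(&[2]));
    /// // Only row id 1 survives: 2 is blocked and nothing else is allowed
    /// assert_eq!(mask.iter_ids().unwrap().count(), 1);
    /// ```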
    pub fn iter_ids(&self) -> Option<Box<dyn Iterator<Item = RowAddress> + '_>> {
        if let Some(mut allow_iter) = self.allow_list.as_ref().and_then(|list| list.row_ids()) {
            if let Some(block_list) = &self.block_list {
                if let Some(block_iter) = block_list.row_ids() {
                    let mut block_iter = block_iter.peekable();
                    Some(Box::new(iter::from_fn(move || {
                        for allow_id in allow_iter.by_ref() {
                            while let Some(block_id) = block_iter.peek() {
                                if *block_id >= allow_id {
                                    break;
                                }
                                block_iter.next();
                            }
                            if let Some(block_id) = block_iter.peek() {
                                if *block_id == allow_id {
                                    continue;
                                }
                            }
                            return Some(allow_id);
                        }
                        None
                    })))
                } else {
                    // There is a block list but we can't iterate over it, give up
                    None
                }
            } else {
                // There is no block list, use the allow list
                Some(Box::new(allow_iter))
            }
        } else {
            None
        }
    }
}

impl std::ops::Not for RowIdMask {
    type Output = Self;

    fn not(self) -> Self::Output {
        Self {
            block_list: self.allow_list,
            allow_list: self.block_list,
        }
    }
}

impl std::ops::BitAnd for RowIdMask {
    type Output = Self;

    fn bitand(self, rhs: Self) -> Self::Output {
        let block_list = match (self.block_list, rhs.block_list) {
            (None, None) => None,
            (Some(lhs), None) => Some(lhs),
            (None, Some(rhs)) => Some(rhs),
            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
        };
        let allow_list = match (self.allow_list, rhs.allow_list) {
            (None, None) => None,
            (Some(lhs), None) => Some(lhs),
            (None, Some(rhs)) => Some(rhs),
            (Some(lhs), Some(rhs)) => Some(lhs & rhs),
        };
        Self {
            block_list,
            allow_list,
        }
    }
}

impl std::ops::BitOr for RowIdMask {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self::Output {
        let block_list = match (self.block_list, rhs.block_list) {
            (None, None) => None,
            (Some(lhs), None) => Some(lhs),
            (None, Some(rhs)) => Some(rhs),
            (Some(lhs), Some(rhs)) => Some(lhs & rhs),
        };
        let allow_list = match (self.allow_list, rhs.allow_list) {
            (None, None) => None,
            // Remember that an allow list of None means "all rows" and
            // so "all rows" | "some rows" is always "all rows"
            (Some(_), None) => None,
            (None, Some(_)) => None,
            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
        };
        Self {
            block_list,
            allow_list,
        }
    }
}

/// A collection of row ids.
///
/// These row ids may either be stable style (where they can be an incrementing
/// u64 sequence) or address style, where they are a fragment id and a row offset.
/// When address style, this supports setting entire fragments as selected,
/// without needing to enumerate all the ids in the fragment.
///
/// This is similar to a [RoaringTreemap] but it is optimized for the case where
/// entire fragments are selected or deselected.
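///
/// For example (a minimal sketch using address-style ids, where the upper 32
/// bits are the fragment id and the lower 32 bits are the row offset):
///
/// ```rust
/// use lance_core::utils::mask::RowIdTreeMap;
///
/// let mut map = RowIdTreeMap::new();
/// map.insert((1u64 << 32) + 10); // fragment 1, offset 10
/// map.insert_fragment(2); // select all of fragment 2
/// assert!(map.contains((1u64 << 32) + 10));
/// assert!(map.contains((2u64 << 32) + 999));
/// // The length is unknown because fragment 2 is a "full fragment" entry
/// assert_eq!(map.len(), None);
/// ```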
#[derive(Clone, Debug, Default, PartialEq, DeepSizeOf)]
pub struct RowIdTreeMap {
    /// The contents of the set. If there is a pair (k, Full) then the entire
    /// fragment k is selected. If there is a pair (k, Partial(v)) then the
    /// fragment k has the selected rows in v.
    inner: BTreeMap<u32, RowIdSelection>,
}

#[derive(Clone, Debug, PartialEq)]
enum RowIdSelection {
    Full,
    Partial(RoaringBitmap),
}

impl DeepSizeOf for RowIdSelection {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        match self {
            Self::Full => 0,
            Self::Partial(bitmap) => bitmap.serialized_size(),
        }
    }
}

impl RowIdSelection {
    fn union_all(selections: &[&Self]) -> Self {
        let mut is_full = false;

        let res = Self::Partial(
            selections
                .iter()
                .filter_map(|selection| match selection {
                    Self::Full => {
                        is_full = true;
                        None
                    }
                    Self::Partial(bitmap) => Some(bitmap),
                })
                .union(),
        );

        if is_full {
            Self::Full
        } else {
            res
        }
    }
}

impl RowIdTreeMap {
    /// Create an empty set
    pub fn new() -> Self {
        Self::default()
    }

    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// The number of rows in the map
    ///
    /// If there are any "full fragment" items then this is unknown and None is returned
    pub fn len(&self) -> Option<u64> {
        self.inner
            .values()
            .map(|row_id_selection| match row_id_selection {
                RowIdSelection::Full => None,
                RowIdSelection::Partial(indices) => Some(indices.len()),
            })
            .try_fold(0_u64, |acc, next| next.map(|next| next + acc))
    }

    /// An iterator of row ids
    ///
    /// If there are any "full fragment" items then the ids can't be enumerated and None
    /// is returned
    pub fn row_ids(&self) -> Option<impl Iterator<Item = RowAddress> + '_> {
        let inner_iters = self
            .inner
            .iter()
            .filter_map(|(frag_id, row_id_selection)| match row_id_selection {
                RowIdSelection::Full => None,
                RowIdSelection::Partial(bitmap) => Some(
                    bitmap
                        .iter()
                        .map(|row_offset| RowAddress::new_from_parts(*frag_id, row_offset)),
                ),
            })
            .collect::<Vec<_>>();
        if inner_iters.len() != self.inner.len() {
            None
        } else {
            Some(inner_iters.into_iter().flatten())
        }
    }

    /// Insert a single value into the set
    ///
    /// Returns true if the value was not already in the set.
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowIdTreeMap;
    ///
    /// let mut set = RowIdTreeMap::new();
    /// assert_eq!(set.insert(10), true);
    /// assert_eq!(set.insert(10), false);
    /// assert_eq!(set.contains(10), true);
    /// ```
    pub fn insert(&mut self, value: u64) -> bool {
        let fragment = (value >> 32) as u32;
        let row_addr = value as u32;
        match self.inner.get_mut(&fragment) {
            None => {
                let mut set = RoaringBitmap::new();
                set.insert(row_addr);
                self.inner.insert(fragment, RowIdSelection::Partial(set));
                true
            }
            Some(RowIdSelection::Full) => false,
            Some(RowIdSelection::Partial(set)) => set.insert(row_addr),
        }
    }

    /// Insert a range of values into the set
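    ///
    /// Returns the number of values that were newly inserted, for example:
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowIdTreeMap;
    ///
    /// let mut set = RowIdTreeMap::new();
    /// assert_eq!(set.insert_range(0..10), 10);
    /// // Only the five new values (10..15) are counted
    /// assert_eq!(set.insert_range(5..15), 5);
    /// ```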
    pub fn insert_range<R: RangeBounds<u64>>(&mut self, range: R) -> u64 {
        // Separate the start and end into high and low bits.
        let (mut start_high, mut start_low) = match range.start_bound() {
            std::ops::Bound::Included(&start) => ((start >> 32) as u32, start as u32),
            std::ops::Bound::Excluded(&start) => {
                let start = start.saturating_add(1);
                ((start >> 32) as u32, start as u32)
            }
            std::ops::Bound::Unbounded => (0, 0),
        };

        let (end_high, end_low) = match range.end_bound() {
            std::ops::Bound::Included(&end) => ((end >> 32) as u32, end as u32),
            std::ops::Bound::Excluded(&end) => {
                let end = end.saturating_sub(1);
                ((end >> 32) as u32, end as u32)
            }
            std::ops::Bound::Unbounded => (u32::MAX, u32::MAX),
        };

        let mut count = 0;

        while start_high <= end_high {
            let start = start_low;
            let end = if start_high == end_high {
                end_low
            } else {
                u32::MAX
            };
            let fragment = start_high;
            match self.inner.get_mut(&fragment) {
                None => {
                    let mut set = RoaringBitmap::new();
                    count += set.insert_range(start..=end);
                    self.inner.insert(fragment, RowIdSelection::Partial(set));
                }
                Some(RowIdSelection::Full) => {}
                Some(RowIdSelection::Partial(set)) => {
                    count += set.insert_range(start..=end);
                }
            }
            start_high += 1;
            start_low = 0;
        }

        count
    }

    /// Add a bitmap for a single fragment
    pub fn insert_bitmap(&mut self, fragment: u32, bitmap: RoaringBitmap) {
        self.inner.insert(fragment, RowIdSelection::Partial(bitmap));
    }

    /// Add a whole fragment to the set
    pub fn insert_fragment(&mut self, fragment_id: u32) {
        self.inner.insert(fragment_id, RowIdSelection::Full);
    }

    pub fn get_fragment_bitmap(&self, fragment_id: u32) -> Option<&RoaringBitmap> {
        match self.inner.get(&fragment_id) {
            None => None,
            Some(RowIdSelection::Full) => None,
            Some(RowIdSelection::Partial(set)) => Some(set),
        }
    }

    /// Returns whether the set contains the given value
    pub fn contains(&self, value: u64) -> bool {
        let upper = (value >> 32) as u32;
        let lower = value as u32;
        match self.inner.get(&upper) {
            None => false,
            Some(RowIdSelection::Full) => true,
            Some(RowIdSelection::Partial(fragment_set)) => fragment_set.contains(lower),
        }
    }

    pub fn remove(&mut self, value: u64) -> bool {
        let upper = (value >> 32) as u32;
        let lower = value as u32;
        match self.inner.get_mut(&upper) {
            None => false,
            Some(RowIdSelection::Full) => {
                let mut set = RoaringBitmap::full();
                set.remove(lower);
                self.inner.insert(upper, RowIdSelection::Partial(set));
                true
            }
            Some(RowIdSelection::Partial(lower_set)) => {
                let removed = lower_set.remove(lower);
                if lower_set.is_empty() {
                    self.inner.remove(&upper);
                }
                removed
            }
        }
    }

    pub fn retain_fragments(&mut self, frag_ids: impl IntoIterator<Item = u32>) {
        let frag_id_set = frag_ids.into_iter().collect::<HashSet<_>>();
        self.inner
            .retain(|frag_id, _| frag_id_set.contains(frag_id));
    }

    /// Compute the serialized size of the set.
    pub fn serialized_size(&self) -> usize {
        // Starts at 4 because of the u32 num_entries
        let mut size = 4;
        for set in self.inner.values() {
            // Each entry has an 8-byte header: a u32 fragment id and a u32 bitmap size
            size += 8;
            if let RowIdSelection::Partial(set) = set {
                size += set.serialized_size();
            }
        }
        size
    }

    /// Serialize the set into the given buffer
    ///
    /// The serialization format is not stable.
    ///
    /// The serialization format is:
    /// * u32: num_entries
    ///
    /// for each entry:
    ///   * u32: fragment_id
    ///   * u32: bitmap size
    ///   * \[u8\]: bitmap
    ///
    /// If bitmap size is zero then the entire fragment is selected.
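    ///
    /// For example (a minimal round-trip sketch):
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowIdTreeMap;
    ///
    /// let mut set = RowIdTreeMap::new();
    /// set.insert(42);
    /// set.insert_fragment(7);
    /// let mut buf = Vec::new();
    /// set.serialize_into(&mut buf).unwrap();
    /// let roundtripped = RowIdTreeMap::deserialize_from(buf.as_slice()).unwrap();
    /// assert_eq!(set, roundtripped);
    /// ```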
    pub fn serialize_into<W: Write>(&self, mut writer: W) -> Result<()> {
        writer.write_u32::<byteorder::LittleEndian>(self.inner.len() as u32)?;
        for (fragment, set) in &self.inner {
            writer.write_u32::<byteorder::LittleEndian>(*fragment)?;
            if let RowIdSelection::Partial(set) = set {
                writer.write_u32::<byteorder::LittleEndian>(set.serialized_size() as u32)?;
                set.serialize_into(&mut writer)?;
            } else {
                writer.write_u32::<byteorder::LittleEndian>(0)?;
            }
        }
        Ok(())
    }

    /// Deserialize the set from the given buffer
    pub fn deserialize_from<R: Read>(mut reader: R) -> Result<Self> {
        let num_entries = reader.read_u32::<byteorder::LittleEndian>()?;
        let mut inner = BTreeMap::new();
        for _ in 0..num_entries {
            let fragment = reader.read_u32::<byteorder::LittleEndian>()?;
            let bitmap_size = reader.read_u32::<byteorder::LittleEndian>()?;
            if bitmap_size == 0 {
                inner.insert(fragment, RowIdSelection::Full);
            } else {
                let mut buffer = vec![0; bitmap_size as usize];
                reader.read_exact(&mut buffer)?;
                let set = RoaringBitmap::deserialize_from(&buffer[..])?;
                inner.insert(fragment, RowIdSelection::Partial(set));
            }
        }
        Ok(Self { inner })
    }

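    /// Compute the union of several maps
    ///
    /// For example (a minimal sketch):
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowIdTreeMap;
    ///
    /// let a = RowIdTreeMap::from_iter(&[1, 2]);
    /// let b = RowIdTreeMap::from_iter(&[2, 3]);
    /// let union = RowIdTreeMap::union_all(&[&a, &b]);
    /// assert_eq!(union, RowIdTreeMap::from_iter(&[1, 2, 3]));
    /// ```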
    pub fn union_all(maps: &[&Self]) -> Self {
        let mut new_map = BTreeMap::new();

        for map in maps {
            for (fragment, selection) in &map.inner {
                new_map
                    .entry(fragment)
                    // I hate this allocation, but I can't think of a better way
                    .or_insert_with(|| Vec::with_capacity(maps.len()))
                    .push(selection);
            }
        }

        let new_map = new_map
            .into_iter()
            .map(|(&fragment, selections)| (fragment, RowIdSelection::union_all(&selections)))
            .collect();

        Self { inner: new_map }
    }

    /// Apply a mask to the row ids
    ///
    /// If there is an allow list then this will intersect the set with the allow list
    /// If there is a block list then this will subtract the block list from the set
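    ///
    /// For example (a minimal sketch):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
    ///
    /// let mut ids = RowIdTreeMap::from_iter(&[1, 2, 3]);
    /// let mask = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[2, 3]))
    ///     .also_block(RowIdTreeMap::from_iter(&[3]));
    /// ids.mask(&mask);
    /// assert_eq!(ids, RowIdTreeMap::from_iter(&[2]));
    /// ```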
    pub fn mask(&mut self, mask: &RowIdMask) {
        if let Some(allow_list) = &mask.allow_list {
            *self &= allow_list;
        }
        if let Some(block_list) = &mask.block_list {
            *self -= block_list;
        }
    }

    /// Convert the set into an iterator of row ids
    ///
    /// # Safety
    ///
    /// This is unsafe because if any of the inner RowIdSelection elements
    /// is not a Partial then the iterator will panic, because the rows of a
    /// full fragment cannot be enumerated.
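    ///
    /// For example (a minimal sketch; safe here because no fragment was
    /// inserted as a whole):
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowIdTreeMap;
    ///
    /// let set = RowIdTreeMap::from_iter(&[(1u64 << 32) + 5, (1u64 << 32) + 7]);
    /// let ids: Vec<u64> = unsafe { set.into_id_iter() }.collect();
    /// assert_eq!(ids, vec![(1u64 << 32) + 5, (1u64 << 32) + 7]);
    /// ```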
    pub unsafe fn into_id_iter(self) -> impl Iterator<Item = u64> {
        self.inner
            .into_iter()
            .flat_map(|(fragment, selection)| match selection {
                RowIdSelection::Full => panic!("Size of full fragment is unknown"),
                RowIdSelection::Partial(bitmap) => bitmap.into_iter().map(move |val| {
                    let fragment = fragment as u64;
                    let row_offset = val as u64;
                    (fragment << 32) | row_offset
                }),
            })
    }
}

impl std::ops::BitOr<Self> for RowIdTreeMap {
    type Output = Self;

    fn bitor(mut self, rhs: Self) -> Self::Output {
        self |= rhs;
        self
    }
}

impl std::ops::BitOrAssign<Self> for RowIdTreeMap {
    fn bitor_assign(&mut self, rhs: Self) {
        for (fragment, rhs_set) in &rhs.inner {
            let lhs_set = self.inner.get_mut(fragment);
            if let Some(lhs_set) = lhs_set {
                match lhs_set {
                    RowIdSelection::Full => {
                        // If the fragment is already selected then there is nothing to do
                    }
                    RowIdSelection::Partial(lhs_bitmap) => match rhs_set {
                        RowIdSelection::Full => {
                            *lhs_set = RowIdSelection::Full;
                        }
                        RowIdSelection::Partial(rhs_set) => {
                            *lhs_bitmap |= rhs_set;
                        }
                    },
                }
            } else {
                self.inner.insert(*fragment, rhs_set.clone());
            }
        }
    }
}

impl std::ops::BitAnd<Self> for RowIdTreeMap {
    type Output = Self;

    fn bitand(mut self, rhs: Self) -> Self::Output {
        self &= &rhs;
        self
    }
}

impl std::ops::BitAndAssign<&Self> for RowIdTreeMap {
    fn bitand_assign(&mut self, rhs: &Self) {
        // Remove fragments that aren't on the RHS
        self.inner
            .retain(|fragment, _| rhs.inner.contains_key(fragment));

        // For fragments that are on the RHS, intersect the bitmaps
        for (fragment, mut lhs_set) in &mut self.inner {
            match (&mut lhs_set, rhs.inner.get(fragment)) {
                (_, None) => {} // Already handled by retain
                (_, Some(RowIdSelection::Full)) => {
                    // Everything selected on RHS, so can leave LHS untouched.
                }
                (RowIdSelection::Partial(lhs_set), Some(RowIdSelection::Partial(rhs_set))) => {
                    *lhs_set &= rhs_set;
                }
                (RowIdSelection::Full, Some(RowIdSelection::Partial(rhs_set))) => {
                    *lhs_set = RowIdSelection::Partial(rhs_set.clone());
                }
            }
        }
        // Some bitmaps might now be empty. If they are, we should remove them.
        self.inner.retain(|_, set| match set {
            RowIdSelection::Partial(set) => !set.is_empty(),
            RowIdSelection::Full => true,
        });
    }
}

impl std::ops::SubAssign<&Self> for RowIdTreeMap {
    fn sub_assign(&mut self, rhs: &Self) {
        for (fragment, rhs_set) in &rhs.inner {
            match self.inner.get_mut(fragment) {
                None => {}
                Some(RowIdSelection::Full) => {
                    // The entire fragment is currently selected on the LHS
                    match rhs_set {
                        RowIdSelection::Full => {
                            self.inner.remove(fragment);
                        }
                        RowIdSelection::Partial(rhs_set) => {
                            // This generally won't be hit.
                            let mut set = RoaringBitmap::full();
                            set -= rhs_set;
                            self.inner.insert(*fragment, RowIdSelection::Partial(set));
                        }
                    }
                }
                Some(RowIdSelection::Partial(lhs_set)) => match rhs_set {
                    RowIdSelection::Full => {
                        self.inner.remove(fragment);
                    }
                    RowIdSelection::Partial(rhs_set) => {
                        *lhs_set -= rhs_set;
                        if lhs_set.is_empty() {
                            self.inner.remove(fragment);
                        }
                    }
                },
            }
        }
    }
}

impl FromIterator<u64> for RowIdTreeMap {
    fn from_iter<T: IntoIterator<Item = u64>>(iter: T) -> Self {
        let mut inner = BTreeMap::new();
        for row_id in iter {
            let upper = (row_id >> 32) as u32;
            let lower = row_id as u32;
            match inner.get_mut(&upper) {
                None => {
                    let mut set = RoaringBitmap::new();
                    set.insert(lower);
                    inner.insert(upper, RowIdSelection::Partial(set));
                }
                Some(RowIdSelection::Full) => {
                    // If the fragment is already selected then there is nothing to do
                }
                Some(RowIdSelection::Partial(set)) => {
                    set.insert(lower);
                }
            }
        }
        Self { inner }
    }
}

impl<'a> FromIterator<&'a u64> for RowIdTreeMap {
    fn from_iter<T: IntoIterator<Item = &'a u64>>(iter: T) -> Self {
        Self::from_iter(iter.into_iter().copied())
    }
}

impl From<Range<u64>> for RowIdTreeMap {
    fn from(range: Range<u64>) -> Self {
        let mut map = Self::default();
        map.insert_range(range);
        map
    }
}

impl From<RoaringTreemap> for RowIdTreeMap {
    fn from(roaring: RoaringTreemap) -> Self {
        let mut inner = BTreeMap::new();
        for (fragment, set) in roaring.bitmaps() {
            inner.insert(fragment, RowIdSelection::Partial(set.clone()));
        }
        Self { inner }
    }
}

impl Extend<u64> for RowIdTreeMap {
    fn extend<T: IntoIterator<Item = u64>>(&mut self, iter: T) {
        for row_id in iter {
            let upper = (row_id >> 32) as u32;
            let lower = row_id as u32;
            match self.inner.get_mut(&upper) {
                None => {
                    let mut set = RoaringBitmap::new();
                    set.insert(lower);
                    self.inner.insert(upper, RowIdSelection::Partial(set));
                }
                Some(RowIdSelection::Full) => {
                    // If the fragment is already selected then there is nothing to do
                }
                Some(RowIdSelection::Partial(set)) => {
                    set.insert(lower);
                }
            }
        }
    }
}

impl<'a> Extend<&'a u64> for RowIdTreeMap {
    fn extend<T: IntoIterator<Item = &'a u64>>(&mut self, iter: T) {
        self.extend(iter.into_iter().copied())
    }
}

// Extending with RowIdTreeMap is basically a cumulative set union
impl Extend<Self> for RowIdTreeMap {
    fn extend<T: IntoIterator<Item = Self>>(&mut self, iter: T) {
        for other in iter {
            for (fragment, set) in other.inner {
                match self.inner.get_mut(&fragment) {
                    None => {
                        self.inner.insert(fragment, set);
                    }
                    Some(RowIdSelection::Full) => {
                        // If the fragment is already selected then there is nothing to do
                    }
                    Some(RowIdSelection::Partial(lhs_set)) => match set {
                        RowIdSelection::Full => {
                            self.inner.insert(fragment, RowIdSelection::Full);
                        }
                        RowIdSelection::Partial(rhs_set) => {
                            *lhs_set |= rhs_set;
                        }
                    },
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prop_assert_eq;

    #[test]
    fn test_ops() {
        let mask = RowIdMask::default();
        assert!(mask.selected(1));
        assert!(mask.selected(5));
        let block_list = mask.also_block(RowIdTreeMap::from_iter(&[0, 5, 15]));
        assert!(block_list.selected(1));
        assert!(!block_list.selected(5));
        let allow_list = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[0, 2, 5]));
        assert!(!allow_list.selected(1));
        assert!(allow_list.selected(5));
        let combined = block_list & allow_list;
        assert!(combined.selected(2));
        assert!(!combined.selected(0));
        assert!(!combined.selected(5));
        let other = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[3]));
        let combined = combined | other;
        assert!(combined.selected(2));
        assert!(combined.selected(3));
        assert!(!combined.selected(0));
        assert!(!combined.selected(5));

        let block_list = RowIdMask::from_block(RowIdTreeMap::from_iter(&[0]));
        let allow_list = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[3]));
        let combined = block_list | allow_list;
        assert!(combined.selected(1));
    }

    #[test]
    fn test_map_insert_range() {
        let ranges = &[
            (0..10),
            (40..500),
            ((u32::MAX as u64 - 10)..(u32::MAX as u64 + 20)),
        ];

        for range in ranges {
            let mut mask = RowIdTreeMap::default();

            let count = mask.insert_range(range.clone());
            let expected = range.end - range.start;
            assert_eq!(count, expected);

            let count = mask.insert_range(range.clone());
            assert_eq!(count, 0);

            let new_range = range.start + 5..range.end + 5;
            let count = mask.insert_range(new_range.clone());
            assert_eq!(count, 5);
        }

        let mut mask = RowIdTreeMap::default();
        let count = mask.insert_range(..10);
        assert_eq!(count, 10);
        assert!(mask.contains(0));

        let count = mask.insert_range(20..=24);
        assert_eq!(count, 5);

        mask.insert_fragment(0);
        let count = mask.insert_range(100..200);
        assert_eq!(count, 0);
    }

    #[test]
    fn test_map_remove() {
        let mut mask = RowIdTreeMap::default();

        assert!(!mask.remove(20));

        mask.insert(20);
        assert!(mask.contains(20));
        assert!(mask.remove(20));
        assert!(!mask.contains(20));

        mask.insert_range(10..=20);
        assert!(mask.contains(15));
        assert!(mask.remove(15));
        assert!(!mask.contains(15));

        // We don't test removing from a full fragment, because that would take
        // a lot of memory.
    }

    proptest::proptest! {
        #[test]
        fn test_map_serialization_roundtrip(
            values in proptest::collection::vec(
                (0..u32::MAX, proptest::option::of(proptest::collection::vec(0..u32::MAX, 0..1000))),
                0..10
            )
        ) {
            let mut mask = RowIdTreeMap::default();
            for (fragment, rows) in values {
                if let Some(rows) = rows {
                    let bitmap = RoaringBitmap::from_iter(rows);
                    mask.insert_bitmap(fragment, bitmap);
                } else {
                    mask.insert_fragment(fragment);
                }
            }

            let mut data = Vec::new();
            mask.serialize_into(&mut data).unwrap();
            let deserialized = RowIdTreeMap::deserialize_from(data.as_slice()).unwrap();
            prop_assert_eq!(mask, deserialized);
        }

        #[test]
        fn test_map_intersect(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowIdTreeMap::default();
            for fragment in left_full_fragments.clone() {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowIdTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }
            right.extend(right_rows.iter().copied());

            let mut expected = RowIdTreeMap::default();
            for fragment in &left_full_fragments {
                if right_full_fragments.contains(fragment) {
                    expected.insert_fragment(*fragment);
                }
            }

            let left_in_right = left_rows.iter().filter(|row| {
                right_rows.contains(row)
                    || right_full_fragments.contains(&((*row >> 32) as u32))
            });
            expected.extend(left_in_right);
            let right_in_left = right_rows.iter().filter(|row| {
                left_rows.contains(row)
                    || left_full_fragments.contains(&((*row >> 32) as u32))
            });
            expected.extend(right_in_left);

            let actual = left & right;
            prop_assert_eq!(expected, actual);
        }

        #[test]
        fn test_map_union(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowIdTreeMap::default();
            for fragment in left_full_fragments.clone() {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowIdTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }
            right.extend(right_rows.iter().copied());

            let mut expected = RowIdTreeMap::default();
            for fragment in left_full_fragments {
                expected.insert_fragment(fragment);
            }
            for fragment in right_full_fragments {
                expected.insert_fragment(fragment);
            }

            let combined_rows = left_rows.iter().chain(right_rows.iter());
            expected.extend(combined_rows);

            let actual = left | right;
            for actual_key_val in &actual.inner {
                proptest::prop_assert!(expected.inner.contains_key(actual_key_val.0));
                let expected_val = expected.inner.get(actual_key_val.0).unwrap();
                prop_assert_eq!(
                    actual_key_val.1,
                    expected_val,
                    "error on key {}",
                    actual_key_val.0
                );
            }
            prop_assert_eq!(expected, actual);
        }

        #[test]
        fn test_map_subassign_rows(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowIdTreeMap::default();
            for fragment in left_full_fragments {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowIdTreeMap::default();
            right.extend(right_rows.iter().copied());

            let mut expected = left.clone();
            for row in right_rows {
                expected.remove(row);
            }

            left -= &right;
            prop_assert_eq!(expected, left);
        }

        #[test]
        fn test_map_subassign_frags(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowIdTreeMap::default();
            for fragment in left_full_fragments {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowIdTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }

            let mut expected = left.clone();
            for fragment in right_full_fragments {
                expected.inner.remove(&fragment);
            }

            left -= &right;
            prop_assert_eq!(expected, left);
        }

    }

    #[test]
    fn test_iter_ids() {
        let mut mask = RowIdMask::default();
        assert!(mask.iter_ids().is_none());

        // Test with just an allow list
        let mut allow_list = RowIdTreeMap::default();
        allow_list.extend([1, 5, 10].iter().copied());
        mask.allow_list = Some(allow_list);

        let ids: Vec<_> = mask.iter_ids().unwrap().collect();
        assert_eq!(
            ids,
            vec![
                RowAddress::new_from_parts(0, 1),
                RowAddress::new_from_parts(0, 5),
                RowAddress::new_from_parts(0, 10)
            ]
        );

        // Test with both allow list and block list
        let mut block_list = RowIdTreeMap::default();
        block_list.extend([5].iter().copied());
        mask.block_list = Some(block_list);

        let ids: Vec<_> = mask.iter_ids().unwrap().collect();
        assert_eq!(
            ids,
            vec![
                RowAddress::new_from_parts(0, 1),
                RowAddress::new_from_parts(0, 10)
            ]
        );

        // Test with full fragment in block list
        let mut block_list = RowIdTreeMap::default();
        block_list.insert_fragment(0);
        mask.block_list = Some(block_list);
        assert!(mask.iter_ids().is_none());

        // Test with full fragment in allow list
        mask.block_list = None;
        let mut allow_list = RowIdTreeMap::default();
        allow_list.insert_fragment(0);
        mask.allow_list = Some(allow_list);
        assert!(mask.iter_ids().is_none());
    }
}