lance_core/utils/mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::collections::HashSet;
5use std::io::Write;
6use std::iter;
7use std::ops::{Range, RangeBounds};
8use std::{collections::BTreeMap, io::Read};
9
10use arrow_array::{Array, BinaryArray, GenericBinaryArray};
11use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer};
12use byteorder::{ReadBytesExt, WriteBytesExt};
13use deepsize::DeepSizeOf;
14use roaring::{MultiOps, RoaringBitmap, RoaringTreemap};
15
16use crate::Result;
17
18use super::address::RowAddress;
19
20/// A row id mask to select or deselect particular row ids
21///
22/// If both the allow_list and the block_list are Some then the only selected
23/// row ids are those that are in the allow_list but not in the block_list
24/// (the block_list takes precedence)
25///
26/// If both the allow_list and the block_list are None (the default) then
27/// all row ids are selected
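///
/// A short illustrative sketch of these semantics (assuming `RowIdMask` is exported
/// from this module alongside `RowIdTreeMap`):
///
/// ```rust
/// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
///
/// let mask = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[1, 2, 3]))
///     .also_block(RowIdTreeMap::from_iter(&[2]));
/// assert!(mask.selected(1));
/// assert!(!mask.selected(2)); // blocked, even though allowed
/// assert!(!mask.selected(100)); // not in the allow list
/// ```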
28#[derive(Clone, Debug, Default, DeepSizeOf)]
29pub struct RowIdMask {
30    /// If Some then only these row ids are selected
31    pub allow_list: Option<RowIdTreeMap>,
32    /// If Some then these row ids are not selected.
33    pub block_list: Option<RowIdTreeMap>,
34}
35
36impl RowIdMask {
    /// Create a mask allowing all rows. This is an alias for `Self::default()`.
38    pub fn all_rows() -> Self {
39        Self::default()
40    }
41
    /// Create a mask that doesn't allow anything
43    pub fn allow_nothing() -> Self {
44        Self {
45            allow_list: Some(RowIdTreeMap::new()),
46            block_list: None,
47        }
48    }
49
    /// Create a mask from an allow list
51    pub fn from_allowed(allow_list: RowIdTreeMap) -> Self {
52        Self {
53            allow_list: Some(allow_list),
54            block_list: None,
55        }
56    }
57
    /// Create a mask from a block list
59    pub fn from_block(block_list: RowIdTreeMap) -> Self {
60        Self {
61            allow_list: None,
62            block_list: Some(block_list),
63        }
64    }
65
    /// If there are both a block list and an allow list, collapse them into a single allow list
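    ///
    /// A brief sketch of the collapse:
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
    ///
    /// let mask = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[1, 2]))
    ///     .also_block(RowIdTreeMap::from_iter(&[2]))
    ///     .normalize();
    /// assert!(mask.block_list.is_none());
    /// assert_eq!(mask.allow_list.unwrap(), RowIdTreeMap::from_iter(&[1]));
    /// ```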
67    pub fn normalize(self) -> Self {
68        if let Self {
69            allow_list: Some(mut allow_list),
70            block_list: Some(block_list),
71        } = self
72        {
73            allow_list -= &block_list;
74            Self {
75                allow_list: Some(allow_list),
76                block_list: None,
77            }
78        } else {
79            self
80        }
81    }
82
83    /// True if the row_id is selected by the mask, false otherwise
84    pub fn selected(&self, row_id: u64) -> bool {
85        match (&self.allow_list, &self.block_list) {
86            (None, None) => true,
87            (Some(allow_list), None) => allow_list.contains(row_id),
88            (None, Some(block_list)) => !block_list.contains(row_id),
89            (Some(allow_list), Some(block_list)) => {
90                allow_list.contains(row_id) && !block_list.contains(row_id)
91            }
92        }
93    }
94
    /// Return the indices of the input row ids that are selected by the mask
96    pub fn selected_indices<'a>(&self, row_ids: impl Iterator<Item = &'a u64> + 'a) -> Vec<u64> {
97        let enumerated_ids = row_ids.enumerate();
98        match (&self.block_list, &self.allow_list) {
99            (Some(block_list), Some(allow_list)) => {
100                // Only take rows that are both in the allow list and not in the block list
101                enumerated_ids
102                    .filter(|(_, row_id)| {
103                        !block_list.contains(**row_id) && allow_list.contains(**row_id)
104                    })
105                    .map(|(idx, _)| idx as u64)
106                    .collect()
107            }
108            (Some(block_list), None) => {
109                // Take rows that are not in the block list
110                enumerated_ids
111                    .filter(|(_, row_id)| !block_list.contains(**row_id))
112                    .map(|(idx, _)| idx as u64)
113                    .collect()
114            }
115            (None, Some(allow_list)) => {
116                // Take rows that are in the allow list
117                enumerated_ids
118                    .filter(|(_, row_id)| allow_list.contains(**row_id))
119                    .map(|(idx, _)| idx as u64)
120                    .collect()
121            }
122            (None, None) => {
123                // We should not encounter this case because callers should
124                // check is_empty first.
125                panic!("selected_indices called but prefilter has nothing to filter with")
126            }
127        }
128    }
129
130    /// Also block the given ids
131    pub fn also_block(self, block_list: RowIdTreeMap) -> Self {
132        if block_list.is_empty() {
133            return self;
134        }
135        if let Some(existing) = self.block_list {
136            Self {
137                block_list: Some(existing | block_list),
138                allow_list: self.allow_list,
139            }
140        } else {
141            Self {
142                block_list: Some(block_list),
143                allow_list: self.allow_list,
144            }
145        }
146    }
147
148    /// Also allow the given ids
149    pub fn also_allow(self, allow_list: RowIdTreeMap) -> Self {
150        if let Some(existing) = self.allow_list {
151            Self {
152                block_list: self.block_list,
153                allow_list: Some(existing | allow_list),
154            }
155        } else {
156            Self {
157                block_list: self.block_list,
158                // allow_list = None means "all rows allowed" and so allowing
159                //              more rows is meaningless
160                allow_list: None,
161            }
162        }
163    }
164
165    /// Convert a mask into an arrow array
166    ///
167    /// A row id mask is not very arrow-compatible.  We can't make it a batch with
168    /// two columns because the block list and allow list will have different lengths.  Also,
169    /// there is no Arrow type for compressed bitmaps.
170    ///
171    /// However, we need to shove it into some kind of Arrow container to pass it along the
172    /// datafusion stream.  Perhaps, in the future, we can add row id masks as first class
173    /// types in datafusion, and this can be passed along as a mask / selection vector.
174    ///
    /// We serialize this as a variable length binary array with two items.  The first item
176    /// is the block list and the second item is the allow list.
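    ///
    /// A minimal round-trip sketch (block list only, so the allow list entry is null):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
    ///
    /// let mask = RowIdMask::from_block(RowIdTreeMap::from_iter(&[1, 2, 3]));
    /// let array = mask.into_arrow().unwrap();
    /// let roundtripped = RowIdMask::from_arrow(&array).unwrap();
    /// assert!(roundtripped.selected(0));
    /// assert!(!roundtripped.selected(2));
    /// ```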
177    pub fn into_arrow(&self) -> Result<BinaryArray> {
178        let block_list_length = self
179            .block_list
180            .as_ref()
181            .map(|bl| bl.serialized_size())
182            .unwrap_or(0);
183        let allow_list_length = self
184            .allow_list
185            .as_ref()
186            .map(|al| al.serialized_size())
187            .unwrap_or(0);
188        let lengths = vec![block_list_length, allow_list_length];
189        let offsets = OffsetBuffer::from_lengths(lengths);
190        let mut value_bytes = vec![0; block_list_length + allow_list_length];
191        let mut validity = vec![false, false];
192        if let Some(block_list) = &self.block_list {
193            validity[0] = true;
194            block_list.serialize_into(&mut value_bytes[0..])?;
195        }
196        if let Some(allow_list) = &self.allow_list {
197            validity[1] = true;
198            allow_list.serialize_into(&mut value_bytes[block_list_length..])?;
199        }
200        let values = Buffer::from(value_bytes);
201        let nulls = NullBuffer::from(validity);
202        Ok(BinaryArray::try_new(offsets, values, Some(nulls))?)
203    }
204
205    /// Deserialize a row id mask from Arrow
206    pub fn from_arrow(array: &GenericBinaryArray<i32>) -> Result<Self> {
207        let block_list = if array.is_null(0) {
208            None
209        } else {
210            Some(RowIdTreeMap::deserialize_from(array.value(0)))
211        }
212        .transpose()?;
213
214        let allow_list = if array.is_null(1) {
215            None
216        } else {
217            Some(RowIdTreeMap::deserialize_from(array.value(1)))
218        }
219        .transpose()?;
220        Ok(Self {
221            block_list,
222            allow_list,
223        })
224    }
225
226    /// Return the maximum number of row ids that could be selected by this mask
227    ///
228    /// Will be None if there is no allow list
229    pub fn max_len(&self) -> Option<u64> {
230        if let Some(allow_list) = &self.allow_list {
231            // If there is a block list we could theoretically intersect the two
232            // but it's not clear if that is worth the effort.  Feel free to add later.
233            allow_list.len()
234        } else {
235            None
236        }
237    }
238
239    /// Iterate over the row ids that are selected by the mask
240    ///
241    /// This is only possible if there is an allow list and neither the
242    /// allow list nor the block list contain any "full fragment" blocks.
243    ///
244    /// TODO: We could probably still iterate efficiently even if the block
245    /// list contains "full fragment" blocks but that would require some
246    /// extra logic.
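    ///
    /// An illustrative sketch (the row ids here are plain offsets within fragment 0):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
    ///
    /// let mask = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[1, 5]))
    ///     .also_block(RowIdTreeMap::from_iter(&[5]));
    /// // Only row 1 survives the block list
    /// assert_eq!(mask.iter_ids().unwrap().count(), 1);
    /// // With no allow list there is nothing to iterate over
    /// assert!(RowIdMask::all_rows().iter_ids().is_none());
    /// ```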
247    pub fn iter_ids(&self) -> Option<Box<dyn Iterator<Item = RowAddress> + '_>> {
248        if let Some(mut allow_iter) = self.allow_list.as_ref().and_then(|list| list.row_ids()) {
249            if let Some(block_list) = &self.block_list {
250                if let Some(block_iter) = block_list.row_ids() {
251                    let mut block_iter = block_iter.peekable();
252                    Some(Box::new(iter::from_fn(move || {
253                        for allow_id in allow_iter.by_ref() {
254                            while let Some(block_id) = block_iter.peek() {
255                                if *block_id >= allow_id {
256                                    break;
257                                }
258                                block_iter.next();
259                            }
260                            if let Some(block_id) = block_iter.peek() {
261                                if *block_id == allow_id {
262                                    continue;
263                                }
264                            }
265                            return Some(allow_id);
266                        }
267                        None
268                    })))
269                } else {
270                    // There is a block list but we can't iterate over it, give up
271                    None
272                }
273            } else {
274                // There is no block list, use the allow list
275                Some(Box::new(allow_iter))
276            }
277        } else {
278            None
279        }
280    }
281}
282
283impl std::ops::Not for RowIdMask {
284    type Output = Self;
285
286    fn not(self) -> Self::Output {
287        Self {
288            block_list: self.allow_list,
289            allow_list: self.block_list,
290        }
291    }
292}
293
294impl std::ops::BitAnd for RowIdMask {
295    type Output = Self;
296
297    fn bitand(self, rhs: Self) -> Self::Output {
298        let block_list = match (self.block_list, rhs.block_list) {
299            (None, None) => None,
300            (Some(lhs), None) => Some(lhs),
301            (None, Some(rhs)) => Some(rhs),
302            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
303        };
304        let allow_list = match (self.allow_list, rhs.allow_list) {
305            (None, None) => None,
306            (Some(lhs), None) => Some(lhs),
307            (None, Some(rhs)) => Some(rhs),
308            (Some(lhs), Some(rhs)) => Some(lhs & rhs),
309        };
310        Self {
311            block_list,
312            allow_list,
313        }
314    }
315}
316
317impl std::ops::BitOr for RowIdMask {
318    type Output = Self;
319
320    fn bitor(self, rhs: Self) -> Self::Output {
321        let this = self.normalize();
322        let rhs = rhs.normalize();
323        let block_list = if let Some(mut self_block_list) = this.block_list {
324            match (&rhs.allow_list, rhs.block_list) {
325                // If RHS is allow all, then our block list disappears
326                (None, None) => None,
327                // If RHS is allow list, remove allowed from our block list
328                (Some(allow_list), None) => {
329                    self_block_list -= allow_list;
330                    Some(self_block_list)
331                }
332                // If RHS is block list, intersect
333                (None, Some(block_list)) => Some(self_block_list & block_list),
334                // We normalized to avoid this path
335                (Some(_), Some(_)) => unreachable!(),
336            }
337        } else if let Some(mut rhs_block_list) = rhs.block_list {
338            if let Some(allow_list) = &this.allow_list {
339                rhs_block_list -= allow_list;
340                Some(rhs_block_list)
341            } else {
342                Some(rhs_block_list)
343            }
344        } else {
345            None
346        };
347
348        let allow_list = match (this.allow_list, rhs.allow_list) {
349            (None, None) => None,
350            // Remember that an allow list of None means "all rows" and
351            // so "all rows" | "some rows" is always "all rows"
352            (Some(_), None) => None,
353            (None, Some(_)) => None,
354            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
355        };
356        Self {
357            block_list,
358            allow_list,
359        }
360    }
361}
362
363/// A collection of row ids.
364///
/// These row ids may either be stable-style (an incrementing u64 sequence) or
/// address-style (a fragment id in the upper 32 bits and a row offset in the lower 32).
/// When address-style, this supports selecting entire fragments without needing to
/// enumerate all of the ids in the fragment.
369///
370/// This is similar to a [RoaringTreemap] but it is optimized for the case where
371/// entire fragments are selected or deselected.
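///
/// An illustrative sketch of the two selection styles:
///
/// ```rust
/// use lance_core::utils::mask::RowIdTreeMap;
///
/// let mut map = RowIdTreeMap::new();
/// // Address-style: fragment id in the upper 32 bits, row offset in the lower 32
/// map.insert((1 << 32) | 7);
/// // Select all of fragment 2 without enumerating its rows
/// map.insert_fragment(2);
/// assert!(map.contains((1 << 32) | 7));
/// assert!(map.contains((2 << 32) | 123));
/// // The exact length is unknown once a full fragment is present
/// assert_eq!(map.len(), None);
/// ```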
372#[derive(Clone, Debug, Default, PartialEq, DeepSizeOf)]
373pub struct RowIdTreeMap {
374    /// The contents of the set. If there is a pair (k, Full) then the entire
375    /// fragment k is selected. If there is a pair (k, Partial(v)) then the
376    /// fragment k has the selected rows in v.
377    inner: BTreeMap<u32, RowIdSelection>,
378}
379
380#[derive(Clone, Debug, PartialEq)]
381enum RowIdSelection {
382    Full,
383    Partial(RoaringBitmap),
384}
385
386impl DeepSizeOf for RowIdSelection {
387    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
388        match self {
389            Self::Full => 0,
390            Self::Partial(bitmap) => bitmap.serialized_size(),
391        }
392    }
393}
394
395impl RowIdSelection {
396    fn union_all(selections: &[&Self]) -> Self {
397        let mut is_full = false;
398
399        let res = Self::Partial(
400            selections
401                .iter()
402                .filter_map(|selection| match selection {
403                    Self::Full => {
404                        is_full = true;
405                        None
406                    }
407                    Self::Partial(bitmap) => Some(bitmap),
408                })
409                .union(),
410        );
411
412        if is_full {
413            Self::Full
414        } else {
415            res
416        }
417    }
418}
419
420impl RowIdTreeMap {
421    /// Create an empty set
422    pub fn new() -> Self {
423        Self::default()
424    }
425
426    pub fn is_empty(&self) -> bool {
427        self.inner.is_empty()
428    }
429
430    /// The number of rows in the map
431    ///
432    /// If there are any "full fragment" items then this is unknown and None is returned
433    pub fn len(&self) -> Option<u64> {
434        self.inner
435            .values()
436            .map(|row_id_selection| match row_id_selection {
437                RowIdSelection::Full => None,
438                RowIdSelection::Partial(indices) => Some(indices.len()),
439            })
440            .try_fold(0_u64, |acc, next| next.map(|next| next + acc))
441    }
442
    /// An iterator over the row ids, as row addresses
    ///
    /// If there are any "full fragment" items then the ids can't be enumerated and None
    /// is returned
447    pub fn row_ids(&self) -> Option<impl Iterator<Item = RowAddress> + '_> {
448        let inner_iters = self
449            .inner
450            .iter()
451            .filter_map(|(frag_id, row_id_selection)| match row_id_selection {
452                RowIdSelection::Full => None,
453                RowIdSelection::Partial(bitmap) => Some(
454                    bitmap
455                        .iter()
456                        .map(|row_offset| RowAddress::new_from_parts(*frag_id, row_offset)),
457                ),
458            })
459            .collect::<Vec<_>>();
460        if inner_iters.len() != self.inner.len() {
461            None
462        } else {
463            Some(inner_iters.into_iter().flatten())
464        }
465    }
466
467    /// Insert a single value into the set
468    ///
469    /// Returns true if the value was not already in the set.
470    ///
471    /// ```rust
472    /// use lance_core::utils::mask::RowIdTreeMap;
473    ///
474    /// let mut set = RowIdTreeMap::new();
475    /// assert_eq!(set.insert(10), true);
476    /// assert_eq!(set.insert(10), false);
477    /// assert_eq!(set.contains(10), true);
478    /// ```
479    pub fn insert(&mut self, value: u64) -> bool {
480        let fragment = (value >> 32) as u32;
481        let row_addr = value as u32;
482        match self.inner.get_mut(&fragment) {
483            None => {
484                let mut set = RoaringBitmap::new();
485                set.insert(row_addr);
486                self.inner.insert(fragment, RowIdSelection::Partial(set));
487                true
488            }
489            Some(RowIdSelection::Full) => false,
490            Some(RowIdSelection::Partial(set)) => set.insert(row_addr),
491        }
492    }
493
494    /// Insert a range of values into the set
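    ///
    /// Returns the number of values that were newly inserted.
    ///
    /// A short sketch:
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowIdTreeMap;
    ///
    /// let mut map = RowIdTreeMap::new();
    /// assert_eq!(map.insert_range(0..5), 5); // five new rows
    /// assert_eq!(map.insert_range(3..8), 3); // only 5, 6, and 7 are new
    /// ```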
495    pub fn insert_range<R: RangeBounds<u64>>(&mut self, range: R) -> u64 {
496        // Separate the start and end into high and low bits.
497        let (mut start_high, mut start_low) = match range.start_bound() {
498            std::ops::Bound::Included(&start) => ((start >> 32) as u32, start as u32),
499            std::ops::Bound::Excluded(&start) => {
500                let start = start.saturating_add(1);
501                ((start >> 32) as u32, start as u32)
502            }
503            std::ops::Bound::Unbounded => (0, 0),
504        };
505
506        let (end_high, end_low) = match range.end_bound() {
507            std::ops::Bound::Included(&end) => ((end >> 32) as u32, end as u32),
508            std::ops::Bound::Excluded(&end) => {
509                let end = end.saturating_sub(1);
510                ((end >> 32) as u32, end as u32)
511            }
512            std::ops::Bound::Unbounded => (u32::MAX, u32::MAX),
513        };
514
515        let mut count = 0;
516
517        while start_high <= end_high {
518            let start = start_low;
519            let end = if start_high == end_high {
520                end_low
521            } else {
522                u32::MAX
523            };
524            let fragment = start_high;
525            match self.inner.get_mut(&fragment) {
526                None => {
527                    let mut set = RoaringBitmap::new();
528                    count += set.insert_range(start..=end);
529                    self.inner.insert(fragment, RowIdSelection::Partial(set));
530                }
531                Some(RowIdSelection::Full) => {}
532                Some(RowIdSelection::Partial(set)) => {
533                    count += set.insert_range(start..=end);
534                }
535            }
536            start_high += 1;
537            start_low = 0;
538        }
539
540        count
541    }
542
543    /// Add a bitmap for a single fragment
544    pub fn insert_bitmap(&mut self, fragment: u32, bitmap: RoaringBitmap) {
545        self.inner.insert(fragment, RowIdSelection::Partial(bitmap));
546    }
547
548    /// Add a whole fragment to the set
549    pub fn insert_fragment(&mut self, fragment_id: u32) {
550        self.inner.insert(fragment_id, RowIdSelection::Full);
551    }
552
553    pub fn get_fragment_bitmap(&self, fragment_id: u32) -> Option<&RoaringBitmap> {
554        match self.inner.get(&fragment_id) {
555            None => None,
556            Some(RowIdSelection::Full) => None,
557            Some(RowIdSelection::Partial(set)) => Some(set),
558        }
559    }
560
561    /// Returns whether the set contains the given value
562    pub fn contains(&self, value: u64) -> bool {
563        let upper = (value >> 32) as u32;
564        let lower = value as u32;
565        match self.inner.get(&upper) {
566            None => false,
567            Some(RowIdSelection::Full) => true,
568            Some(RowIdSelection::Partial(fragment_set)) => fragment_set.contains(lower),
569        }
570    }
571
572    pub fn remove(&mut self, value: u64) -> bool {
573        let upper = (value >> 32) as u32;
574        let lower = value as u32;
575        match self.inner.get_mut(&upper) {
576            None => false,
577            Some(RowIdSelection::Full) => {
578                let mut set = RoaringBitmap::full();
579                set.remove(lower);
580                self.inner.insert(upper, RowIdSelection::Partial(set));
581                true
582            }
583            Some(RowIdSelection::Partial(lower_set)) => {
584                let removed = lower_set.remove(lower);
585                if lower_set.is_empty() {
586                    self.inner.remove(&upper);
587                }
588                removed
589            }
590        }
591    }
592
593    pub fn retain_fragments(&mut self, frag_ids: impl IntoIterator<Item = u32>) {
594        let frag_id_set = frag_ids.into_iter().collect::<HashSet<_>>();
595        self.inner
596            .retain(|frag_id, _| frag_id_set.contains(frag_id));
597    }
598
599    /// Compute the serialized size of the set.
600    pub fn serialized_size(&self) -> usize {
601        // Starts at 4 because of the u32 num_entries
602        let mut size = 4;
603        for set in self.inner.values() {
604            // Each entry is 8 bytes for the fragment id and the bitmap size
605            size += 8;
606            if let RowIdSelection::Partial(set) = set {
607                size += set.serialized_size();
608            }
609        }
610        size
611    }
612
613    /// Serialize the set into the given buffer
614    ///
615    /// The serialization format is not stable.
616    ///
617    /// The serialization format is:
618    /// * u32: num_entries
619    ///
620    /// for each entry:
621    ///   * u32: fragment_id
622    ///   * u32: bitmap size
623    ///   * \[u8\]: bitmap
624    ///
625    /// If bitmap size is zero then the entire fragment is selected.
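    ///
    /// A small round-trip sketch:
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowIdTreeMap;
    ///
    /// let mut map = RowIdTreeMap::new();
    /// map.insert(42);
    /// map.insert_fragment(7);
    ///
    /// let mut buf: Vec<u8> = Vec::with_capacity(map.serialized_size());
    /// map.serialize_into(&mut buf).unwrap();
    /// let roundtripped = RowIdTreeMap::deserialize_from(buf.as_slice()).unwrap();
    /// assert_eq!(map, roundtripped);
    /// ```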
626    pub fn serialize_into<W: Write>(&self, mut writer: W) -> Result<()> {
627        writer.write_u32::<byteorder::LittleEndian>(self.inner.len() as u32)?;
628        for (fragment, set) in &self.inner {
629            writer.write_u32::<byteorder::LittleEndian>(*fragment)?;
630            if let RowIdSelection::Partial(set) = set {
631                writer.write_u32::<byteorder::LittleEndian>(set.serialized_size() as u32)?;
632                set.serialize_into(&mut writer)?;
633            } else {
634                writer.write_u32::<byteorder::LittleEndian>(0)?;
635            }
636        }
637        Ok(())
638    }
639
640    /// Deserialize the set from the given buffer
641    pub fn deserialize_from<R: Read>(mut reader: R) -> Result<Self> {
642        let num_entries = reader.read_u32::<byteorder::LittleEndian>()?;
643        let mut inner = BTreeMap::new();
644        for _ in 0..num_entries {
645            let fragment = reader.read_u32::<byteorder::LittleEndian>()?;
646            let bitmap_size = reader.read_u32::<byteorder::LittleEndian>()?;
647            if bitmap_size == 0 {
648                inner.insert(fragment, RowIdSelection::Full);
649            } else {
650                let mut buffer = vec![0; bitmap_size as usize];
651                reader.read_exact(&mut buffer)?;
652                let set = RoaringBitmap::deserialize_from(&buffer[..])?;
653                inner.insert(fragment, RowIdSelection::Partial(set));
654            }
655        }
656        Ok(Self { inner })
657    }
658
659    pub fn union_all(maps: &[&Self]) -> Self {
660        let mut new_map = BTreeMap::new();
661
662        for map in maps {
663            for (fragment, selection) in &map.inner {
664                new_map
665                    .entry(fragment)
666                    // I hate this allocation, but I can't think of a better way
667                    .or_insert_with(|| Vec::with_capacity(maps.len()))
668                    .push(selection);
669            }
670        }
671
672        let new_map = new_map
673            .into_iter()
674            .map(|(&fragment, selections)| (fragment, RowIdSelection::union_all(&selections)))
675            .collect();
676
677        Self { inner: new_map }
678    }
679
680    /// Apply a mask to the row ids
681    ///
682    /// If there is an allow list then this will intersect the set with the allow list
683    /// If there is a block list then this will subtract the block list from the set
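    ///
    /// A brief sketch:
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowIdMask, RowIdTreeMap};
    ///
    /// let mut set = RowIdTreeMap::from_iter(&[1, 2, 3]);
    /// let mask = RowIdMask::from_block(RowIdTreeMap::from_iter(&[2]));
    /// set.mask(&mask);
    /// assert_eq!(set, RowIdTreeMap::from_iter(&[1, 3]));
    /// ```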
684    pub fn mask(&mut self, mask: &RowIdMask) {
685        if let Some(allow_list) = &mask.allow_list {
686            *self &= allow_list;
687        }
688        if let Some(block_list) = &mask.block_list {
689            *self -= block_list;
690        }
691    }
692
693    /// Convert the set into an iterator of row addrs
694    ///
695    /// # Safety
696    ///
    /// This is unsafe because if any of the inner RowIdSelection elements
    /// is not a Partial then the iterator will panic, since we don't know how
    /// many rows are in a full fragment.
700    pub unsafe fn into_addr_iter(self) -> impl Iterator<Item = u64> {
701        self.inner
702            .into_iter()
703            .flat_map(|(fragment, selection)| match selection {
704                RowIdSelection::Full => panic!("Size of full fragment is unknown"),
705                RowIdSelection::Partial(bitmap) => bitmap.into_iter().map(move |val| {
706                    let fragment = fragment as u64;
707                    let row_offset = val as u64;
708                    (fragment << 32) | row_offset
709                }),
710            })
711    }
712}
713
714impl std::ops::BitOr<Self> for RowIdTreeMap {
715    type Output = Self;
716
717    fn bitor(mut self, rhs: Self) -> Self::Output {
718        self |= rhs;
719        self
720    }
721}
722
723impl std::ops::BitOrAssign<Self> for RowIdTreeMap {
724    fn bitor_assign(&mut self, rhs: Self) {
725        for (fragment, rhs_set) in &rhs.inner {
726            let lhs_set = self.inner.get_mut(fragment);
727            if let Some(lhs_set) = lhs_set {
728                match lhs_set {
729                    RowIdSelection::Full => {
730                        // If the fragment is already selected then there is nothing to do
731                    }
732                    RowIdSelection::Partial(lhs_bitmap) => match rhs_set {
733                        RowIdSelection::Full => {
734                            *lhs_set = RowIdSelection::Full;
735                        }
736                        RowIdSelection::Partial(rhs_set) => {
737                            *lhs_bitmap |= rhs_set;
738                        }
739                    },
740                }
741            } else {
742                self.inner.insert(*fragment, rhs_set.clone());
743            }
744        }
745    }
746}
747
748impl std::ops::BitAnd<Self> for RowIdTreeMap {
749    type Output = Self;
750
751    fn bitand(mut self, rhs: Self) -> Self::Output {
752        self &= &rhs;
753        self
754    }
755}
756
757impl std::ops::BitAndAssign<&Self> for RowIdTreeMap {
758    fn bitand_assign(&mut self, rhs: &Self) {
        // Remove fragments that aren't in the RHS
760        self.inner
761            .retain(|fragment, _| rhs.inner.contains_key(fragment));
762
763        // For fragments that are on the RHS, intersect the bitmaps
764        for (fragment, mut lhs_set) in &mut self.inner {
765            match (&mut lhs_set, rhs.inner.get(fragment)) {
766                (_, None) => {} // Already handled by retain
767                (_, Some(RowIdSelection::Full)) => {
768                    // Everything selected on RHS, so can leave LHS untouched.
769                }
770                (RowIdSelection::Partial(lhs_set), Some(RowIdSelection::Partial(rhs_set))) => {
771                    *lhs_set &= rhs_set;
772                }
773                (RowIdSelection::Full, Some(RowIdSelection::Partial(rhs_set))) => {
774                    *lhs_set = RowIdSelection::Partial(rhs_set.clone());
775                }
776            }
777        }
778        // Some bitmaps might now be empty. If they are, we should remove them.
779        self.inner.retain(|_, set| match set {
780            RowIdSelection::Partial(set) => !set.is_empty(),
781            RowIdSelection::Full => true,
782        });
783    }
784}
785
786impl std::ops::Sub<Self> for RowIdTreeMap {
787    type Output = Self;
788
789    fn sub(mut self, rhs: Self) -> Self {
790        self -= &rhs;
791        self
792    }
793}
794
795impl std::ops::SubAssign<&Self> for RowIdTreeMap {
796    fn sub_assign(&mut self, rhs: &Self) {
797        for (fragment, rhs_set) in &rhs.inner {
798            match self.inner.get_mut(fragment) {
799                None => {}
800                Some(RowIdSelection::Full) => {
                    // The whole fragment is currently selected; remove whatever the RHS subtracts
802                    match rhs_set {
803                        RowIdSelection::Full => {
804                            self.inner.remove(fragment);
805                        }
806                        RowIdSelection::Partial(rhs_set) => {
807                            // This generally won't be hit.
808                            let mut set = RoaringBitmap::full();
809                            set -= rhs_set;
810                            self.inner.insert(*fragment, RowIdSelection::Partial(set));
811                        }
812                    }
813                }
814                Some(RowIdSelection::Partial(lhs_set)) => match rhs_set {
815                    RowIdSelection::Full => {
816                        self.inner.remove(fragment);
817                    }
818                    RowIdSelection::Partial(rhs_set) => {
819                        *lhs_set -= rhs_set;
820                        if lhs_set.is_empty() {
821                            self.inner.remove(fragment);
822                        }
823                    }
824                },
825            }
826        }
827    }
828}
829
830impl FromIterator<u64> for RowIdTreeMap {
831    fn from_iter<T: IntoIterator<Item = u64>>(iter: T) -> Self {
832        let mut inner = BTreeMap::new();
833        for row_id in iter {
834            let upper = (row_id >> 32) as u32;
835            let lower = row_id as u32;
836            match inner.get_mut(&upper) {
837                None => {
838                    let mut set = RoaringBitmap::new();
839                    set.insert(lower);
840                    inner.insert(upper, RowIdSelection::Partial(set));
841                }
842                Some(RowIdSelection::Full) => {
843                    // If the fragment is already selected then there is nothing to do
844                }
845                Some(RowIdSelection::Partial(set)) => {
846                    set.insert(lower);
847                }
848            }
849        }
850        Self { inner }
851    }
852}
853
854impl<'a> FromIterator<&'a u64> for RowIdTreeMap {
855    fn from_iter<T: IntoIterator<Item = &'a u64>>(iter: T) -> Self {
856        Self::from_iter(iter.into_iter().copied())
857    }
858}
859
860impl From<Range<u64>> for RowIdTreeMap {
861    fn from(range: Range<u64>) -> Self {
862        let mut map = Self::default();
863        map.insert_range(range);
864        map
865    }
866}
867
868impl From<RoaringTreemap> for RowIdTreeMap {
869    fn from(roaring: RoaringTreemap) -> Self {
870        let mut inner = BTreeMap::new();
871        for (fragment, set) in roaring.bitmaps() {
872            inner.insert(fragment, RowIdSelection::Partial(set.clone()));
873        }
874        Self { inner }
875    }
876}
877
878impl Extend<u64> for RowIdTreeMap {
879    fn extend<T: IntoIterator<Item = u64>>(&mut self, iter: T) {
880        for row_id in iter {
881            let upper = (row_id >> 32) as u32;
882            let lower = row_id as u32;
883            match self.inner.get_mut(&upper) {
884                None => {
885                    let mut set = RoaringBitmap::new();
886                    set.insert(lower);
887                    self.inner.insert(upper, RowIdSelection::Partial(set));
888                }
889                Some(RowIdSelection::Full) => {
890                    // If the fragment is already selected then there is nothing to do
891                }
892                Some(RowIdSelection::Partial(set)) => {
893                    set.insert(lower);
894                }
895            }
896        }
897    }
898}
899
900impl<'a> Extend<&'a u64> for RowIdTreeMap {
901    fn extend<T: IntoIterator<Item = &'a u64>>(&mut self, iter: T) {
902        self.extend(iter.into_iter().copied())
903    }
904}
905
906// Extending with RowIdTreeMap is basically a cumulative set union
907impl Extend<Self> for RowIdTreeMap {
908    fn extend<T: IntoIterator<Item = Self>>(&mut self, iter: T) {
909        for other in iter {
910            for (fragment, set) in other.inner {
911                match self.inner.get_mut(&fragment) {
912                    None => {
913                        self.inner.insert(fragment, set);
914                    }
915                    Some(RowIdSelection::Full) => {
916                        // If the fragment is already selected then there is nothing to do
917                    }
918                    Some(RowIdSelection::Partial(lhs_set)) => match set {
919                        RowIdSelection::Full => {
920                            self.inner.insert(fragment, RowIdSelection::Full);
921                        }
922                        RowIdSelection::Partial(rhs_set) => {
923                            *lhs_set |= rhs_set;
924                        }
925                    },
926                }
927            }
928        }
929    }
930}
931
932#[cfg(test)]
933mod tests {
934    use super::*;
935    use proptest::prop_assert_eq;
936
937    #[test]
938    fn test_ops() {
939        let mask = RowIdMask::default();
940        assert!(mask.selected(1));
941        assert!(mask.selected(5));
942        let block_list = mask.also_block(RowIdTreeMap::from_iter(&[0, 5, 15]));
943        assert!(block_list.selected(1));
944        assert!(!block_list.selected(5));
945        let allow_list = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[0, 2, 5]));
946        assert!(!allow_list.selected(1));
947        assert!(allow_list.selected(5));
948        let combined = block_list & allow_list;
949        assert!(combined.selected(2));
950        assert!(!combined.selected(0));
951        assert!(!combined.selected(5));
952        let other = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[3]));
953        let combined = combined | other;
954        assert!(combined.selected(2));
955        assert!(combined.selected(3));
956        assert!(!combined.selected(0));
957        assert!(!combined.selected(5));
958
959        let block_list = RowIdMask::from_block(RowIdTreeMap::from_iter(&[0]));
960        let allow_list = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[3]));
961        let combined = block_list | allow_list;
962        assert!(combined.selected(1));
963    }
964
965    #[test]
966    fn test_logical_or() {
967        let allow1 = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[5, 6, 7, 8, 9]));
968        let block1 = RowIdMask::from_block(RowIdTreeMap::from_iter(&[5, 6]));
969        let mixed1 = allow1
970            .clone()
971            .also_block(block1.block_list.as_ref().unwrap().clone());
972        let allow2 = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[2, 3, 4, 5, 6, 7, 8]));
973        let block2 = RowIdMask::from_block(RowIdTreeMap::from_iter(&[4, 5]));
974        let mixed2 = allow2
975            .clone()
976            .also_block(block2.block_list.as_ref().unwrap().clone());
977
978        fn check(lhs: &RowIdMask, rhs: &RowIdMask, expected: &[u64]) {
979            for mask in [lhs.clone() | rhs.clone(), rhs.clone() | lhs.clone()] {
980                let values = (0..10)
981                    .filter(|val| mask.selected(*val))
982                    .collect::<Vec<_>>();
983                assert_eq!(&values, expected);
984            }
985        }
986
987        check(&allow1, &allow1, &[5, 6, 7, 8, 9]);
988        check(&block1, &block1, &[0, 1, 2, 3, 4, 7, 8, 9]);
989        check(&mixed1, &mixed1, &[7, 8, 9]);
990        check(&allow2, &allow2, &[2, 3, 4, 5, 6, 7, 8]);
991        check(&block2, &block2, &[0, 1, 2, 3, 6, 7, 8, 9]);
992        check(&mixed2, &mixed2, &[2, 3, 6, 7, 8]);
993
994        check(&allow1, &block1, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
995        check(&allow1, &mixed1, &[5, 6, 7, 8, 9]);
996        check(&allow1, &allow2, &[2, 3, 4, 5, 6, 7, 8, 9]);
997        check(&allow1, &block2, &[0, 1, 2, 3, 5, 6, 7, 8, 9]);
998        check(&allow1, &mixed2, &[2, 3, 5, 6, 7, 8, 9]);
999        check(&block1, &mixed1, &[0, 1, 2, 3, 4, 7, 8, 9]);
1000        check(&block1, &allow2, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
1001        check(&block1, &block2, &[0, 1, 2, 3, 4, 6, 7, 8, 9]);
1002        check(&block1, &mixed2, &[0, 1, 2, 3, 4, 6, 7, 8, 9]);
1003        check(&mixed1, &allow2, &[2, 3, 4, 5, 6, 7, 8, 9]);
1004        check(&mixed1, &block2, &[0, 1, 2, 3, 6, 7, 8, 9]);
1005        check(&mixed1, &mixed2, &[2, 3, 6, 7, 8, 9]);
1006        check(&allow2, &block2, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
1007        check(&allow2, &mixed2, &[2, 3, 4, 5, 6, 7, 8]);
1008        check(&block2, &mixed2, &[0, 1, 2, 3, 6, 7, 8, 9]);
1009    }
1010
1011    #[test]
1012    fn test_map_insert_range() {
1013        let ranges = &[
1014            (0..10),
1015            (40..500),
1016            ((u32::MAX as u64 - 10)..(u32::MAX as u64 + 20)),
1017        ];
1018
1019        for range in ranges {
1020            let mut mask = RowIdTreeMap::default();
1021
1022            let count = mask.insert_range(range.clone());
1023            let expected = range.end - range.start;
1024            assert_eq!(count, expected);
1025
1026            let count = mask.insert_range(range.clone());
1027            assert_eq!(count, 0);
1028
1029            let new_range = range.start + 5..range.end + 5;
1030            let count = mask.insert_range(new_range.clone());
1031            assert_eq!(count, 5);
1032        }
1033
1034        let mut mask = RowIdTreeMap::default();
1035        let count = mask.insert_range(..10);
1036        assert_eq!(count, 10);
1037        assert!(mask.contains(0));
1038
1039        let count = mask.insert_range(20..=24);
1040        assert_eq!(count, 5);
1041
1042        mask.insert_fragment(0);
1043        let count = mask.insert_range(100..200);
1044        assert_eq!(count, 0);
1045    }
1046
1047    #[test]
1048    fn test_map_remove() {
1049        let mut mask = RowIdTreeMap::default();
1050
1051        assert!(!mask.remove(20));
1052
1053        mask.insert(20);
1054        assert!(mask.contains(20));
1055        assert!(mask.remove(20));
1056        assert!(!mask.contains(20));
1057
1058        mask.insert_range(10..=20);
1059        assert!(mask.contains(15));
1060        assert!(mask.remove(15));
1061        assert!(!mask.contains(15));
1062
1063        // We don't test removing from a full fragment, because that would take
1064        // a lot of memory.
1065    }
1066
1067    proptest::proptest! {
1068        #[test]
1069        fn test_map_serialization_roundtrip(
1070            values in proptest::collection::vec(
1071                (0..u32::MAX, proptest::option::of(proptest::collection::vec(0..u32::MAX, 0..1000))),
1072                0..10
1073            )
1074        ) {
1075            let mut mask = RowIdTreeMap::default();
1076            for (fragment, rows) in values {
1077                if let Some(rows) = rows {
1078                    let bitmap = RoaringBitmap::from_iter(rows);
1079                    mask.insert_bitmap(fragment, bitmap);
1080                } else {
1081                    mask.insert_fragment(fragment);
1082                }
1083            }
1084
1085            let mut data = Vec::new();
1086            mask.serialize_into(&mut data).unwrap();
1087            let deserialized = RowIdTreeMap::deserialize_from(data.as_slice()).unwrap();
1088            prop_assert_eq!(mask, deserialized);
1089        }
1090
1091        #[test]
1092        fn test_map_intersect(
1093            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1094            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1095            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1096            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1097        ) {
1098            let mut left = RowIdTreeMap::default();
1099            for fragment in left_full_fragments.clone() {
1100                left.insert_fragment(fragment);
1101            }
1102            left.extend(left_rows.iter().copied());
1103
1104            let mut right = RowIdTreeMap::default();
1105            for fragment in right_full_fragments.clone() {
1106                right.insert_fragment(fragment);
1107            }
1108            right.extend(right_rows.iter().copied());
1109
1110            let mut expected = RowIdTreeMap::default();
1111            for fragment in &left_full_fragments {
1112                if right_full_fragments.contains(fragment) {
1113                    expected.insert_fragment(*fragment);
1114                }
1115            }
1116
1117            let left_in_right = left_rows.iter().filter(|row| {
1118                right_rows.contains(row)
1119                    || right_full_fragments.contains(&((*row >> 32) as u32))
1120            });
1121            expected.extend(left_in_right);
1122            let right_in_left = right_rows.iter().filter(|row| {
1123                left_rows.contains(row)
1124                    || left_full_fragments.contains(&((*row >> 32) as u32))
1125            });
1126            expected.extend(right_in_left);
1127
1128            let actual = left & right;
1129            prop_assert_eq!(expected, actual);
1130        }
1131
1132        #[test]
1133        fn test_map_union(
1134            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1135            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1136            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1137            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1138        ) {
1139            let mut left = RowIdTreeMap::default();
1140            for fragment in left_full_fragments.clone() {
1141                left.insert_fragment(fragment);
1142            }
1143            left.extend(left_rows.iter().copied());
1144
1145            let mut right = RowIdTreeMap::default();
1146            for fragment in right_full_fragments.clone() {
1147                right.insert_fragment(fragment);
1148            }
1149            right.extend(right_rows.iter().copied());
1150
1151            let mut expected = RowIdTreeMap::default();
1152            for fragment in left_full_fragments {
1153                expected.insert_fragment(fragment);
1154            }
1155            for fragment in right_full_fragments {
1156                expected.insert_fragment(fragment);
1157            }
1158
1159            let combined_rows = left_rows.iter().chain(right_rows.iter());
1160            expected.extend(combined_rows);
1161
1162            let actual = left | right;
1163            for actual_key_val in &actual.inner {
1164                proptest::prop_assert!(expected.inner.contains_key(actual_key_val.0));
1165                let expected_val = expected.inner.get(actual_key_val.0).unwrap();
1166                prop_assert_eq!(
1167                    actual_key_val.1,
1168                    expected_val,
1169                    "error on key {}",
1170                    actual_key_val.0
1171                );
1172            }
1173            prop_assert_eq!(expected, actual);
1174        }
1175
1176        #[test]
1177        fn test_map_subassign_rows(
1178            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1179            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1180            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1181        ) {
1182            let mut left = RowIdTreeMap::default();
1183            for fragment in left_full_fragments {
1184                left.insert_fragment(fragment);
1185            }
1186            left.extend(left_rows.iter().copied());
1187
1188            let mut right = RowIdTreeMap::default();
1189            right.extend(right_rows.iter().copied());
1190
1191            let mut expected = left.clone();
1192            for row in right_rows {
1193                expected.remove(row);
1194            }
1195
1196            left -= &right;
1197            prop_assert_eq!(expected, left);
1198        }
1199
1200        #[test]
1201        fn test_map_subassign_frags(
1202            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1203            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1204            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1205        ) {
1206            let mut left = RowIdTreeMap::default();
1207            for fragment in left_full_fragments {
1208                left.insert_fragment(fragment);
1209            }
1210            left.extend(left_rows.iter().copied());
1211
1212            let mut right = RowIdTreeMap::default();
1213            for fragment in right_full_fragments.clone() {
1214                right.insert_fragment(fragment);
1215            }
1216
1217            let mut expected = left.clone();
1218            for fragment in right_full_fragments {
1219                expected.inner.remove(&fragment);
1220            }
1221
1222            left -= &right;
1223            prop_assert_eq!(expected, left);
1224        }
1225
1226    }
1227
1228    #[test]
1229    fn test_iter_ids() {
1230        let mut mask = RowIdMask::default();
1231        assert!(mask.iter_ids().is_none());
1232
1233        // Test with just an allow list
1234        let mut allow_list = RowIdTreeMap::default();
1235        allow_list.extend([1, 5, 10].iter().copied());
1236        mask.allow_list = Some(allow_list);
1237
1238        let ids: Vec<_> = mask.iter_ids().unwrap().collect();
1239        assert_eq!(
1240            ids,
1241            vec![
1242                RowAddress::new_from_parts(0, 1),
1243                RowAddress::new_from_parts(0, 5),
1244                RowAddress::new_from_parts(0, 10)
1245            ]
1246        );
1247
1248        // Test with both allow list and block list
1249        let mut block_list = RowIdTreeMap::default();
1250        block_list.extend([5].iter().copied());
1251        mask.block_list = Some(block_list);
1252
1253        let ids: Vec<_> = mask.iter_ids().unwrap().collect();
1254        assert_eq!(
1255            ids,
1256            vec![
1257                RowAddress::new_from_parts(0, 1),
1258                RowAddress::new_from_parts(0, 10)
1259            ]
1260        );
1261
1262        // Test with full fragment in block list
1263        let mut block_list = RowIdTreeMap::default();
1264        block_list.insert_fragment(0);
1265        mask.block_list = Some(block_list);
1266        assert!(mask.iter_ids().is_none());
1267
1268        // Test with full fragment in allow list
1269        mask.block_list = None;
1270        let mut allow_list = RowIdTreeMap::default();
1271        allow_list.insert_fragment(0);
1272        mask.allow_list = Some(allow_list);
1273        assert!(mask.iter_ids().is_none());
1274    }
1275}