lance_core/utils/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::collections::HashSet;
5use std::io::Write;
6use std::iter;
7use std::ops::{Range, RangeBounds};
8use std::{collections::BTreeMap, io::Read};
9
10use arrow_array::{Array, BinaryArray, GenericBinaryArray};
11use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer};
12use byteorder::{ReadBytesExt, WriteBytesExt};
13use deepsize::DeepSizeOf;
14use roaring::{MultiOps, RoaringBitmap, RoaringTreemap};
15
16use crate::Result;
17
18use super::address::RowAddress;
19
/// A row id mask to select or deselect particular row ids
///
/// If both the allow_list and the block_list are Some then the only selected
/// row ids are those that are in the allow_list but not in the block_list
/// (the block_list takes precedence)
///
/// If both the allow_list and the block_list are None (the default) then
/// all row ids are selected
#[derive(Clone, Debug, Default, DeepSizeOf)]
pub struct RowIdMask {
    /// If Some then only these row ids are selected
    ///
    /// None means "no constraint", i.e. all row ids are allowed.
    pub allow_list: Option<RowIdTreeMap>,
    /// If Some then these row ids are not selected.
    ///
    /// A row id present in both lists is not selected (block wins).
    pub block_list: Option<RowIdTreeMap>,
}
35
impl RowIdMask {
    /// Create a mask allowing all rows, this is an alias for [default]
    pub fn all_rows() -> Self {
        Self::default()
    }

    /// Create a mask that doesn't allow anything
    ///
    /// An empty allow list can never match any row id.
    pub fn allow_nothing() -> Self {
        Self {
            allow_list: Some(RowIdTreeMap::new()),
            block_list: None,
        }
    }

    /// Create a mask from an allow list
    pub fn from_allowed(allow_list: RowIdTreeMap) -> Self {
        Self {
            allow_list: Some(allow_list),
            block_list: None,
        }
    }

    /// Create a mask from a block list
    pub fn from_block(block_list: RowIdTreeMap) -> Self {
        Self {
            allow_list: None,
            block_list: Some(block_list),
        }
    }

    /// If there is both a block list and an allow list then collapse into just an allow list
    ///
    /// The block list takes precedence, so the result is `allow_list - block_list`.
    pub fn normalize(self) -> Self {
        if let Self {
            allow_list: Some(mut allow_list),
            block_list: Some(block_list),
        } = self
        {
            allow_list -= &block_list;
            Self {
                allow_list: Some(allow_list),
                block_list: None,
            }
        } else {
            self
        }
    }

    /// True if the row_id is selected by the mask, false otherwise
    pub fn selected(&self, row_id: u64) -> bool {
        match (&self.allow_list, &self.block_list) {
            // No lists at all means every row id is selected
            (None, None) => true,
            (Some(allow_list), None) => allow_list.contains(row_id),
            (None, Some(block_list)) => !block_list.contains(row_id),
            (Some(allow_list), Some(block_list)) => {
                allow_list.contains(row_id) && !block_list.contains(row_id)
            }
        }
    }

    /// Return the indices of the input row ids that were valid
    ///
    /// # Panics
    ///
    /// Panics if the mask has neither an allow list nor a block list.
    /// Callers are expected to rule that case out before calling.
    pub fn selected_indices<'a>(&self, row_ids: impl Iterator<Item = &'a u64> + 'a) -> Vec<u64> {
        let enumerated_ids = row_ids.enumerate();
        match (&self.block_list, &self.allow_list) {
            (Some(block_list), Some(allow_list)) => {
                // Only take rows that are both in the allow list and not in the block list
                enumerated_ids
                    .filter(|(_, row_id)| {
                        !block_list.contains(**row_id) && allow_list.contains(**row_id)
                    })
                    .map(|(idx, _)| idx as u64)
                    .collect()
            }
            (Some(block_list), None) => {
                // Take rows that are not in the block list
                enumerated_ids
                    .filter(|(_, row_id)| !block_list.contains(**row_id))
                    .map(|(idx, _)| idx as u64)
                    .collect()
            }
            (None, Some(allow_list)) => {
                // Take rows that are in the allow list
                enumerated_ids
                    .filter(|(_, row_id)| allow_list.contains(**row_id))
                    .map(|(idx, _)| idx as u64)
                    .collect()
            }
            (None, None) => {
                // We should not encounter this case because callers should
                // check is_empty first.
                panic!("selected_indices called but prefilter has nothing to filter with")
            }
        }
    }

    /// Also block the given ids
    ///
    /// Blocking nothing is a no-op; otherwise the new block list is unioned
    /// with any existing one.
    pub fn also_block(self, block_list: RowIdTreeMap) -> Self {
        if block_list.is_empty() {
            return self;
        }
        if let Some(existing) = self.block_list {
            Self {
                block_list: Some(existing | block_list),
                allow_list: self.allow_list,
            }
        } else {
            Self {
                block_list: Some(block_list),
                allow_list: self.allow_list,
            }
        }
    }

    /// Also allow the given ids
    ///
    /// The new allow list is unioned with any existing one.
    pub fn also_allow(self, allow_list: RowIdTreeMap) -> Self {
        if let Some(existing) = self.allow_list {
            Self {
                block_list: self.block_list,
                allow_list: Some(existing | allow_list),
            }
        } else {
            Self {
                block_list: self.block_list,
                // allow_list = None means "all rows allowed" and so allowing
                //              more rows is meaningless
                allow_list: None,
            }
        }
    }

    /// Convert a mask into an arrow array
    ///
    /// A row id mask is not very arrow-compatible.  We can't make it a batch with
    /// two columns because the block list and allow list will have different lengths.  Also,
    /// there is no Arrow type for compressed bitmaps.
    ///
    /// However, we need to shove it into some kind of Arrow container to pass it along the
    /// datafusion stream.  Perhaps, in the future, we can add row id masks as first class
    /// types in datafusion, and this can be passed along as a mask / selection vector.
    ///
    /// We serialize this as a variable length binary array with two items.  The first item
    /// is the block list and the second item is the allow list.  A missing list is encoded
    /// as a null entry.
    pub fn into_arrow(&self) -> Result<BinaryArray> {
        let block_list_length = self
            .block_list
            .as_ref()
            .map(|bl| bl.serialized_size())
            .unwrap_or(0);
        let allow_list_length = self
            .allow_list
            .as_ref()
            .map(|al| al.serialized_size())
            .unwrap_or(0);
        let lengths = vec![block_list_length, allow_list_length];
        let offsets = OffsetBuffer::from_lengths(lengths);
        let mut value_bytes = vec![0; block_list_length + allow_list_length];
        // Validity flags mark which of the two lists is actually present
        let mut validity = vec![false, false];
        if let Some(block_list) = &self.block_list {
            validity[0] = true;
            block_list.serialize_into(&mut value_bytes[0..])?;
        }
        if let Some(allow_list) = &self.allow_list {
            validity[1] = true;
            allow_list.serialize_into(&mut value_bytes[block_list_length..])?;
        }
        let values = Buffer::from(value_bytes);
        let nulls = NullBuffer::from(validity);
        Ok(BinaryArray::try_new(offsets, values, Some(nulls))?)
    }

    /// Deserialize a row id mask from Arrow
    ///
    /// Expects the two-entry layout produced by [Self::into_arrow]
    /// (entry 0 = block list, entry 1 = allow list).
    /// NOTE(review): an array with fewer than two entries will panic on the
    /// index access — confirm callers always pass the full layout.
    pub fn from_arrow(array: &GenericBinaryArray<i32>) -> Result<Self> {
        let block_list = if array.is_null(0) {
            None
        } else {
            Some(RowIdTreeMap::deserialize_from(array.value(0)))
        }
        .transpose()?;

        let allow_list = if array.is_null(1) {
            None
        } else {
            Some(RowIdTreeMap::deserialize_from(array.value(1)))
        }
        .transpose()?;
        Ok(Self {
            block_list,
            allow_list,
        })
    }

    /// Return the maximum number of row ids that could be selected by this mask
    ///
    /// Will be None if there is no allow list (i.e. potentially unbounded) or
    /// if the allow list contains "full fragment" entries of unknown size.
    pub fn max_len(&self) -> Option<u64> {
        if let Some(allow_list) = &self.allow_list {
            // If there is a block list we could theoretically intersect the two
            // but it's not clear if that is worth the effort.  Feel free to add later.
            allow_list.len()
        } else {
            None
        }
    }

    /// Iterate over the row ids that are selected by the mask
    ///
    /// This is only possible if there is an allow list and neither the
    /// allow list nor the block list contain any "full fragment" blocks.
    ///
    /// TODO: We could probably still iterate efficiently even if the block
    /// list contains "full fragment" blocks but that would require some
    /// extra logic.
    pub fn iter_ids(&self) -> Option<Box<dyn Iterator<Item = RowAddress> + '_>> {
        if let Some(mut allow_iter) = self.allow_list.as_ref().and_then(|list| list.row_ids()) {
            if let Some(block_list) = &self.block_list {
                if let Some(block_iter) = block_list.row_ids() {
                    // Merge-join the two iterators: for each allowed id, skip
                    // it if the block iterator also yields it.  This relies on
                    // both iterators yielding addresses in ascending order.
                    let mut block_iter = block_iter.peekable();
                    Some(Box::new(iter::from_fn(move || {
                        for allow_id in allow_iter.by_ref() {
                            // Advance the block iterator past ids below the
                            // current allowed id
                            while let Some(block_id) = block_iter.peek() {
                                if *block_id >= allow_id {
                                    break;
                                }
                                block_iter.next();
                            }
                            // If the block iterator is parked exactly on this
                            // id, the id is blocked; try the next allowed id
                            if let Some(block_id) = block_iter.peek() {
                                if *block_id == allow_id {
                                    continue;
                                }
                            }
                            return Some(allow_id);
                        }
                        None
                    })))
                } else {
                    // There is a block list but we can't iterate over it, give up
                    None
                }
            } else {
                // There is no block list, use the allow list
                Some(Box::new(allow_iter))
            }
        } else {
            None
        }
    }
}
282
283impl std::ops::Not for RowIdMask {
284    type Output = Self;
285
286    fn not(self) -> Self::Output {
287        Self {
288            block_list: self.allow_list,
289            allow_list: self.block_list,
290        }
291    }
292}
293
294impl std::ops::BitAnd for RowIdMask {
295    type Output = Self;
296
297    fn bitand(self, rhs: Self) -> Self::Output {
298        let block_list = match (self.block_list, rhs.block_list) {
299            (None, None) => None,
300            (Some(lhs), None) => Some(lhs),
301            (None, Some(rhs)) => Some(rhs),
302            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
303        };
304        let allow_list = match (self.allow_list, rhs.allow_list) {
305            (None, None) => None,
306            (Some(lhs), None) => Some(lhs),
307            (None, Some(rhs)) => Some(rhs),
308            (Some(lhs), Some(rhs)) => Some(lhs & rhs),
309        };
310        Self {
311            block_list,
312            allow_list,
313        }
314    }
315}
316
impl std::ops::BitOr for RowIdMask {
    type Output = Self;

    /// Union two masks: a row is selected if it is selected by either side.
    ///
    /// Both sides are normalized first so that neither side carries both an
    /// allow list and a block list at the same time.
    fn bitor(self, rhs: Self) -> Self::Output {
        let this = self.normalize();
        let rhs = rhs.normalize();
        // A row is blocked by the union only if it is blocked by one side
        // and not selected by the other.
        let block_list = if let Some(mut self_block_list) = this.block_list {
            match (&rhs.allow_list, rhs.block_list) {
                // If RHS is allow all, then our block list disappears
                (None, None) => None,
                // If RHS is allow list, remove allowed from our block list
                (Some(allow_list), None) => {
                    self_block_list -= allow_list;
                    Some(self_block_list)
                }
                // If RHS is block list, intersect
                (None, Some(block_list)) => Some(self_block_list & block_list),
                // We normalized to avoid this path
                (Some(_), Some(_)) => unreachable!(),
            }
        } else if let Some(mut rhs_block_list) = rhs.block_list {
            // Mirror of the branch above; LHS has no block list here
            if let Some(allow_list) = &this.allow_list {
                rhs_block_list -= allow_list;
                Some(rhs_block_list)
            } else {
                Some(rhs_block_list)
            }
        } else {
            None
        };

        let allow_list = match (this.allow_list, rhs.allow_list) {
            (None, None) => None,
            // Remember that an allow list of None means "all rows" and
            // so "all rows" | "some rows" is always "all rows"
            (Some(_), None) => None,
            (None, Some(_)) => None,
            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
        };
        Self {
            block_list,
            allow_list,
        }
    }
}
362
/// A collection of row ids.
///
/// These row ids may either be stable-style (where they can be an incrementing
/// u64 sequence) or address style, where they are a fragment id and a row offset.
/// When address style, this supports setting entire fragments as selected,
/// without needing to enumerate all the ids in the fragment.
///
/// This is similar to a [RoaringTreemap] but it is optimized for the case where
/// entire fragments are selected or deselected.
#[derive(Clone, Debug, Default, PartialEq, DeepSizeOf)]
pub struct RowIdTreeMap {
    /// The contents of the set. If there is a pair (k, Full) then the entire
    /// fragment k is selected. If there is a pair (k, Partial(v)) then the
    /// fragment k has the selected rows in v.
    ///
    /// The key is the high 32 bits of the row address; the bitmap holds the
    /// low 32 bits (the row offset within the fragment).
    inner: BTreeMap<u32, RowAddrSelection>,
}
379
/// The selection state of a single fragment within a [RowIdTreeMap]
#[derive(Clone, Debug, PartialEq)]
enum RowAddrSelection {
    /// Every row in the fragment is selected
    Full,
    /// Only the row offsets present in the bitmap are selected
    Partial(RoaringBitmap),
}
385
386impl DeepSizeOf for RowAddrSelection {
387    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
388        match self {
389            Self::Full => 0,
390            Self::Partial(bitmap) => bitmap.serialized_size(),
391        }
392    }
393}
394
395impl RowAddrSelection {
396    fn union_all(selections: &[&Self]) -> Self {
397        let mut is_full = false;
398
399        let res = Self::Partial(
400            selections
401                .iter()
402                .filter_map(|selection| match selection {
403                    Self::Full => {
404                        is_full = true;
405                        None
406                    }
407                    Self::Partial(bitmap) => Some(bitmap),
408                })
409                .union(),
410        );
411
412        if is_full {
413            Self::Full
414        } else {
415            res
416        }
417    }
418}
419
420impl RowIdTreeMap {
421    /// Create an empty set
422    pub fn new() -> Self {
423        Self::default()
424    }
425
426    pub fn is_empty(&self) -> bool {
427        self.inner.is_empty()
428    }
429
430    /// The number of rows in the map
431    ///
432    /// If there are any "full fragment" items then this is unknown and None is returned
433    pub fn len(&self) -> Option<u64> {
434        self.inner
435            .values()
436            .map(|row_addr_selection| match row_addr_selection {
437                RowAddrSelection::Full => None,
438                RowAddrSelection::Partial(indices) => Some(indices.len()),
439            })
440            .try_fold(0_u64, |acc, next| next.map(|next| next + acc))
441    }
442
443    /// An iterator of row ids
444    ///
445    /// If there are any "full fragment" items then this can't be calculated and None
446    /// is returned
447    pub fn row_ids(&self) -> Option<impl Iterator<Item = RowAddress> + '_> {
448        let inner_iters = self
449            .inner
450            .iter()
451            .filter_map(|(frag_id, row_addr_selection)| match row_addr_selection {
452                RowAddrSelection::Full => None,
453                RowAddrSelection::Partial(bitmap) => Some(
454                    bitmap
455                        .iter()
456                        .map(|row_offset| RowAddress::new_from_parts(*frag_id, row_offset)),
457                ),
458            })
459            .collect::<Vec<_>>();
460        if inner_iters.len() != self.inner.len() {
461            None
462        } else {
463            Some(inner_iters.into_iter().flatten())
464        }
465    }
466
467    /// Insert a single value into the set
468    ///
469    /// Returns true if the value was not already in the set.
470    ///
471    /// ```rust
472    /// use lance_core::utils::mask::RowIdTreeMap;
473    ///
474    /// let mut set = RowIdTreeMap::new();
475    /// assert_eq!(set.insert(10), true);
476    /// assert_eq!(set.insert(10), false);
477    /// assert_eq!(set.contains(10), true);
478    /// ```
479    pub fn insert(&mut self, value: u64) -> bool {
480        let fragment = (value >> 32) as u32;
481        let row_addr = value as u32;
482        match self.inner.get_mut(&fragment) {
483            None => {
484                let mut set = RoaringBitmap::new();
485                set.insert(row_addr);
486                self.inner.insert(fragment, RowAddrSelection::Partial(set));
487                true
488            }
489            Some(RowAddrSelection::Full) => false,
490            Some(RowAddrSelection::Partial(set)) => set.insert(row_addr),
491        }
492    }
493
494    /// Insert a range of values into the set
495    pub fn insert_range<R: RangeBounds<u64>>(&mut self, range: R) -> u64 {
496        // Separate the start and end into high and low bits.
497        let (mut start_high, mut start_low) = match range.start_bound() {
498            std::ops::Bound::Included(&start) => ((start >> 32) as u32, start as u32),
499            std::ops::Bound::Excluded(&start) => {
500                let start = start.saturating_add(1);
501                ((start >> 32) as u32, start as u32)
502            }
503            std::ops::Bound::Unbounded => (0, 0),
504        };
505
506        let (end_high, end_low) = match range.end_bound() {
507            std::ops::Bound::Included(&end) => ((end >> 32) as u32, end as u32),
508            std::ops::Bound::Excluded(&end) => {
509                let end = end.saturating_sub(1);
510                ((end >> 32) as u32, end as u32)
511            }
512            std::ops::Bound::Unbounded => (u32::MAX, u32::MAX),
513        };
514
515        let mut count = 0;
516
517        while start_high <= end_high {
518            let start = start_low;
519            let end = if start_high == end_high {
520                end_low
521            } else {
522                u32::MAX
523            };
524            let fragment = start_high;
525            match self.inner.get_mut(&fragment) {
526                None => {
527                    let mut set = RoaringBitmap::new();
528                    count += set.insert_range(start..=end);
529                    self.inner.insert(fragment, RowAddrSelection::Partial(set));
530                }
531                Some(RowAddrSelection::Full) => {}
532                Some(RowAddrSelection::Partial(set)) => {
533                    count += set.insert_range(start..=end);
534                }
535            }
536            start_high += 1;
537            start_low = 0;
538        }
539
540        count
541    }
542
543    /// Add a bitmap for a single fragment
544    pub fn insert_bitmap(&mut self, fragment: u32, bitmap: RoaringBitmap) {
545        self.inner
546            .insert(fragment, RowAddrSelection::Partial(bitmap));
547    }
548
549    /// Add a whole fragment to the set
550    pub fn insert_fragment(&mut self, fragment_id: u32) {
551        self.inner.insert(fragment_id, RowAddrSelection::Full);
552    }
553
554    pub fn get_fragment_bitmap(&self, fragment_id: u32) -> Option<&RoaringBitmap> {
555        match self.inner.get(&fragment_id) {
556            None => None,
557            Some(RowAddrSelection::Full) => None,
558            Some(RowAddrSelection::Partial(set)) => Some(set),
559        }
560    }
561
562    /// Returns whether the set contains the given value
563    pub fn contains(&self, value: u64) -> bool {
564        let upper = (value >> 32) as u32;
565        let lower = value as u32;
566        match self.inner.get(&upper) {
567            None => false,
568            Some(RowAddrSelection::Full) => true,
569            Some(RowAddrSelection::Partial(fragment_set)) => fragment_set.contains(lower),
570        }
571    }
572
573    pub fn remove(&mut self, value: u64) -> bool {
574        let upper = (value >> 32) as u32;
575        let lower = value as u32;
576        match self.inner.get_mut(&upper) {
577            None => false,
578            Some(RowAddrSelection::Full) => {
579                let mut set = RoaringBitmap::full();
580                set.remove(lower);
581                self.inner.insert(upper, RowAddrSelection::Partial(set));
582                true
583            }
584            Some(RowAddrSelection::Partial(lower_set)) => {
585                let removed = lower_set.remove(lower);
586                if lower_set.is_empty() {
587                    self.inner.remove(&upper);
588                }
589                removed
590            }
591        }
592    }
593
594    pub fn retain_fragments(&mut self, frag_ids: impl IntoIterator<Item = u32>) {
595        let frag_id_set = frag_ids.into_iter().collect::<HashSet<_>>();
596        self.inner
597            .retain(|frag_id, _| frag_id_set.contains(frag_id));
598    }
599
600    /// Compute the serialized size of the set.
601    pub fn serialized_size(&self) -> usize {
602        // Starts at 4 because of the u32 num_entries
603        let mut size = 4;
604        for set in self.inner.values() {
605            // Each entry is 8 bytes for the fragment id and the bitmap size
606            size += 8;
607            if let RowAddrSelection::Partial(set) = set {
608                size += set.serialized_size();
609            }
610        }
611        size
612    }
613
614    /// Serialize the set into the given buffer
615    ///
616    /// The serialization format is stable and used for index serialization
617    ///
618    /// The serialization format is:
619    /// * u32: num_entries
620    ///
621    /// for each entry:
622    ///   * u32: fragment_id
623    ///   * u32: bitmap size
624    ///   * \[u8\]: bitmap
625    ///
626    /// If bitmap size is zero then the entire fragment is selected.
627    pub fn serialize_into<W: Write>(&self, mut writer: W) -> Result<()> {
628        writer.write_u32::<byteorder::LittleEndian>(self.inner.len() as u32)?;
629        for (fragment, set) in &self.inner {
630            writer.write_u32::<byteorder::LittleEndian>(*fragment)?;
631            if let RowAddrSelection::Partial(set) = set {
632                writer.write_u32::<byteorder::LittleEndian>(set.serialized_size() as u32)?;
633                set.serialize_into(&mut writer)?;
634            } else {
635                writer.write_u32::<byteorder::LittleEndian>(0)?;
636            }
637        }
638        Ok(())
639    }
640
641    /// Deserialize the set from the given buffer
642    pub fn deserialize_from<R: Read>(mut reader: R) -> Result<Self> {
643        let num_entries = reader.read_u32::<byteorder::LittleEndian>()?;
644        let mut inner = BTreeMap::new();
645        for _ in 0..num_entries {
646            let fragment = reader.read_u32::<byteorder::LittleEndian>()?;
647            let bitmap_size = reader.read_u32::<byteorder::LittleEndian>()?;
648            if bitmap_size == 0 {
649                inner.insert(fragment, RowAddrSelection::Full);
650            } else {
651                let mut buffer = vec![0; bitmap_size as usize];
652                reader.read_exact(&mut buffer)?;
653                let set = RoaringBitmap::deserialize_from(&buffer[..])?;
654                inner.insert(fragment, RowAddrSelection::Partial(set));
655            }
656        }
657        Ok(Self { inner })
658    }
659
660    pub fn union_all(maps: &[&Self]) -> Self {
661        let mut new_map = BTreeMap::new();
662
663        for map in maps {
664            for (fragment, selection) in &map.inner {
665                new_map
666                    .entry(fragment)
667                    // I hate this allocation, but I can't think of a better way
668                    .or_insert_with(|| Vec::with_capacity(maps.len()))
669                    .push(selection);
670            }
671        }
672
673        let new_map = new_map
674            .into_iter()
675            .map(|(&fragment, selections)| (fragment, RowAddrSelection::union_all(&selections)))
676            .collect();
677
678        Self { inner: new_map }
679    }
680
681    /// Apply a mask to the row ids
682    ///
683    /// If there is an allow list then this will intersect the set with the allow list
684    /// If there is a block list then this will subtract the block list from the set
685    pub fn mask(&mut self, mask: &RowIdMask) {
686        if let Some(allow_list) = &mask.allow_list {
687            *self &= allow_list;
688        }
689        if let Some(block_list) = &mask.block_list {
690            *self -= block_list;
691        }
692    }
693
694    /// Convert the set into an iterator of row addrs
695    ///
696    /// # Safety
697    ///
698    /// This is unsafe because if any of the inner RowAddrSelection elements
699    /// is not a Partial then the iterator will panic because we don't know
700    /// the size of the bitmap.
701    pub unsafe fn into_addr_iter(self) -> impl Iterator<Item = u64> {
702        self.inner
703            .into_iter()
704            .flat_map(|(fragment, selection)| match selection {
705                RowAddrSelection::Full => panic!("Size of full fragment is unknown"),
706                RowAddrSelection::Partial(bitmap) => bitmap.into_iter().map(move |val| {
707                    let fragment = fragment as u64;
708                    let row_offset = val as u64;
709                    (fragment << 32) | row_offset
710                }),
711            })
712    }
713}
714
715impl std::ops::BitOr<Self> for RowIdTreeMap {
716    type Output = Self;
717
718    fn bitor(mut self, rhs: Self) -> Self::Output {
719        self |= rhs;
720        self
721    }
722}
723
724impl std::ops::BitOrAssign<Self> for RowIdTreeMap {
725    fn bitor_assign(&mut self, rhs: Self) {
726        for (fragment, rhs_set) in &rhs.inner {
727            let lhs_set = self.inner.get_mut(fragment);
728            if let Some(lhs_set) = lhs_set {
729                match lhs_set {
730                    RowAddrSelection::Full => {
731                        // If the fragment is already selected then there is nothing to do
732                    }
733                    RowAddrSelection::Partial(lhs_bitmap) => match rhs_set {
734                        RowAddrSelection::Full => {
735                            *lhs_set = RowAddrSelection::Full;
736                        }
737                        RowAddrSelection::Partial(rhs_set) => {
738                            *lhs_bitmap |= rhs_set;
739                        }
740                    },
741                }
742            } else {
743                self.inner.insert(*fragment, rhs_set.clone());
744            }
745        }
746    }
747}
748
749impl std::ops::BitAnd<Self> for RowIdTreeMap {
750    type Output = Self;
751
752    fn bitand(mut self, rhs: Self) -> Self::Output {
753        self &= &rhs;
754        self
755    }
756}
757
758impl std::ops::BitAndAssign<&Self> for RowIdTreeMap {
759    fn bitand_assign(&mut self, rhs: &Self) {
760        // Remove fragment that aren't on the RHS
761        self.inner
762            .retain(|fragment, _| rhs.inner.contains_key(fragment));
763
764        // For fragments that are on the RHS, intersect the bitmaps
765        for (fragment, mut lhs_set) in &mut self.inner {
766            match (&mut lhs_set, rhs.inner.get(fragment)) {
767                (_, None) => {} // Already handled by retain
768                (_, Some(RowAddrSelection::Full)) => {
769                    // Everything selected on RHS, so can leave LHS untouched.
770                }
771                (RowAddrSelection::Partial(lhs_set), Some(RowAddrSelection::Partial(rhs_set))) => {
772                    *lhs_set &= rhs_set;
773                }
774                (RowAddrSelection::Full, Some(RowAddrSelection::Partial(rhs_set))) => {
775                    *lhs_set = RowAddrSelection::Partial(rhs_set.clone());
776                }
777            }
778        }
779        // Some bitmaps might now be empty. If they are, we should remove them.
780        self.inner.retain(|_, set| match set {
781            RowAddrSelection::Partial(set) => !set.is_empty(),
782            RowAddrSelection::Full => true,
783        });
784    }
785}
786
787impl std::ops::Sub<Self> for RowIdTreeMap {
788    type Output = Self;
789
790    fn sub(mut self, rhs: Self) -> Self {
791        self -= &rhs;
792        self
793    }
794}
795
impl std::ops::SubAssign<&Self> for RowIdTreeMap {
    /// Remove every row id present in `rhs` from `self` (in-place difference).
    fn sub_assign(&mut self, rhs: &Self) {
        for (fragment, rhs_set) in &rhs.inner {
            match self.inner.get_mut(fragment) {
                // Fragment not present on the left; nothing to remove
                None => {}
                Some(RowAddrSelection::Full) => {
                    // The whole fragment is selected on the left; subtract
                    // whatever the right side selects for it
                    match rhs_set {
                        RowAddrSelection::Full => {
                            self.inner.remove(fragment);
                        }
                        RowAddrSelection::Partial(rhs_set) => {
                            // This generally won't be hit.
                            let mut set = RoaringBitmap::full();
                            set -= rhs_set;
                            self.inner.insert(*fragment, RowAddrSelection::Partial(set));
                        }
                    }
                }
                Some(RowAddrSelection::Partial(lhs_set)) => match rhs_set {
                    RowAddrSelection::Full => {
                        self.inner.remove(fragment);
                    }
                    RowAddrSelection::Partial(rhs_set) => {
                        *lhs_set -= rhs_set;
                        // Drop the fragment entry once it selects nothing
                        if lhs_set.is_empty() {
                            self.inner.remove(fragment);
                        }
                    }
                },
            }
        }
    }
}
830
831impl FromIterator<u64> for RowIdTreeMap {
832    fn from_iter<T: IntoIterator<Item = u64>>(iter: T) -> Self {
833        let mut inner = BTreeMap::new();
834        for row_id in iter {
835            let upper = (row_id >> 32) as u32;
836            let lower = row_id as u32;
837            match inner.get_mut(&upper) {
838                None => {
839                    let mut set = RoaringBitmap::new();
840                    set.insert(lower);
841                    inner.insert(upper, RowAddrSelection::Partial(set));
842                }
843                Some(RowAddrSelection::Full) => {
844                    // If the fragment is already selected then there is nothing to do
845                }
846                Some(RowAddrSelection::Partial(set)) => {
847                    set.insert(lower);
848                }
849            }
850        }
851        Self { inner }
852    }
853}
854
855impl<'a> FromIterator<&'a u64> for RowIdTreeMap {
856    fn from_iter<T: IntoIterator<Item = &'a u64>>(iter: T) -> Self {
857        Self::from_iter(iter.into_iter().copied())
858    }
859}
860
861impl From<Range<u64>> for RowIdTreeMap {
862    fn from(range: Range<u64>) -> Self {
863        let mut map = Self::default();
864        map.insert_range(range);
865        map
866    }
867}
868
869impl From<RoaringTreemap> for RowIdTreeMap {
870    fn from(roaring: RoaringTreemap) -> Self {
871        let mut inner = BTreeMap::new();
872        for (fragment, set) in roaring.bitmaps() {
873            inner.insert(fragment, RowAddrSelection::Partial(set.clone()));
874        }
875        Self { inner }
876    }
877}
878
879impl Extend<u64> for RowIdTreeMap {
880    fn extend<T: IntoIterator<Item = u64>>(&mut self, iter: T) {
881        for row_id in iter {
882            let upper = (row_id >> 32) as u32;
883            let lower = row_id as u32;
884            match self.inner.get_mut(&upper) {
885                None => {
886                    let mut set = RoaringBitmap::new();
887                    set.insert(lower);
888                    self.inner.insert(upper, RowAddrSelection::Partial(set));
889                }
890                Some(RowAddrSelection::Full) => {
891                    // If the fragment is already selected then there is nothing to do
892                }
893                Some(RowAddrSelection::Partial(set)) => {
894                    set.insert(lower);
895                }
896            }
897        }
898    }
899}
900
901impl<'a> Extend<&'a u64> for RowIdTreeMap {
902    fn extend<T: IntoIterator<Item = &'a u64>>(&mut self, iter: T) {
903        self.extend(iter.into_iter().copied())
904    }
905}
906
907// Extending with RowIdTreeMap is basically a cumulative set union
908impl Extend<Self> for RowIdTreeMap {
909    fn extend<T: IntoIterator<Item = Self>>(&mut self, iter: T) {
910        for other in iter {
911            for (fragment, set) in other.inner {
912                match self.inner.get_mut(&fragment) {
913                    None => {
914                        self.inner.insert(fragment, set);
915                    }
916                    Some(RowAddrSelection::Full) => {
917                        // If the fragment is already selected then there is nothing to do
918                    }
919                    Some(RowAddrSelection::Partial(lhs_set)) => match set {
920                        RowAddrSelection::Full => {
921                            self.inner.insert(fragment, RowAddrSelection::Full);
922                        }
923                        RowAddrSelection::Partial(rhs_set) => {
924                            *lhs_set |= rhs_set;
925                        }
926                    },
927                }
928            }
929        }
930    }
931}
932
#[cfg(test)]
mod tests {
    // Unit and property-based tests for RowIdMask / RowIdTreeMap.
    use super::*;
    use proptest::prop_assert_eq;

    // Exercises the boolean algebra of masks: the default mask allows
    // everything, block lists subtract, allow lists restrict, and `&` / `|`
    // combine masks (block lists take precedence over allow lists).
    #[test]
    fn test_ops() {
        let mask = RowIdMask::default();
        assert!(mask.selected(1));
        assert!(mask.selected(5));
        let block_list = mask.also_block(RowIdTreeMap::from_iter(&[0, 5, 15]));
        assert!(block_list.selected(1));
        assert!(!block_list.selected(5));
        let allow_list = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[0, 2, 5]));
        assert!(!allow_list.selected(1));
        assert!(allow_list.selected(5));
        let combined = block_list & allow_list;
        assert!(combined.selected(2));
        assert!(!combined.selected(0));
        assert!(!combined.selected(5));
        let other = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[3]));
        let combined = combined | other;
        assert!(combined.selected(2));
        assert!(combined.selected(3));
        assert!(!combined.selected(0));
        assert!(!combined.selected(5));

        // A block list OR an allow list: a row is selected if either side
        // selects it (1 is neither blocked by the left nor allowed by the
        // right, but the left mask allows everything except 0).
        let block_list = RowIdMask::from_block(RowIdTreeMap::from_iter(&[0]));
        let allow_list = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[3]));
        let combined = block_list | allow_list;
        assert!(combined.selected(1));
    }

    // `|` must select a row iff either operand selects it; checks every
    // pairing of pure-allow, pure-block, and mixed masks over rows 0..10,
    // in both argument orders (the operation is commutative).
    #[test]
    fn test_logical_or() {
        let allow1 = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[5, 6, 7, 8, 9]));
        let block1 = RowIdMask::from_block(RowIdTreeMap::from_iter(&[5, 6]));
        let mixed1 = allow1
            .clone()
            .also_block(block1.block_list.as_ref().unwrap().clone());
        let allow2 = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[2, 3, 4, 5, 6, 7, 8]));
        let block2 = RowIdMask::from_block(RowIdTreeMap::from_iter(&[4, 5]));
        let mixed2 = allow2
            .clone()
            .also_block(block2.block_list.as_ref().unwrap().clone());

        // Asserts that `lhs | rhs` (in both orders) selects exactly
        // `expected` out of rows 0..10.
        fn check(lhs: &RowIdMask, rhs: &RowIdMask, expected: &[u64]) {
            for mask in [lhs.clone() | rhs.clone(), rhs.clone() | lhs.clone()] {
                let values = (0..10)
                    .filter(|val| mask.selected(*val))
                    .collect::<Vec<_>>();
                assert_eq!(&values, expected);
            }
        }

        // Each mask OR'd with itself is itself.
        check(&allow1, &allow1, &[5, 6, 7, 8, 9]);
        check(&block1, &block1, &[0, 1, 2, 3, 4, 7, 8, 9]);
        check(&mixed1, &mixed1, &[7, 8, 9]);
        check(&allow2, &allow2, &[2, 3, 4, 5, 6, 7, 8]);
        check(&block2, &block2, &[0, 1, 2, 3, 6, 7, 8, 9]);
        check(&mixed2, &mixed2, &[2, 3, 6, 7, 8]);

        // All distinct cross-pairings.
        check(&allow1, &block1, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        check(&allow1, &mixed1, &[5, 6, 7, 8, 9]);
        check(&allow1, &allow2, &[2, 3, 4, 5, 6, 7, 8, 9]);
        check(&allow1, &block2, &[0, 1, 2, 3, 5, 6, 7, 8, 9]);
        check(&allow1, &mixed2, &[2, 3, 5, 6, 7, 8, 9]);
        check(&block1, &mixed1, &[0, 1, 2, 3, 4, 7, 8, 9]);
        check(&block1, &allow2, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        check(&block1, &block2, &[0, 1, 2, 3, 4, 6, 7, 8, 9]);
        check(&block1, &mixed2, &[0, 1, 2, 3, 4, 6, 7, 8, 9]);
        check(&mixed1, &allow2, &[2, 3, 4, 5, 6, 7, 8, 9]);
        check(&mixed1, &block2, &[0, 1, 2, 3, 6, 7, 8, 9]);
        check(&mixed1, &mixed2, &[2, 3, 6, 7, 8, 9]);
        check(&allow2, &block2, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        check(&allow2, &mixed2, &[2, 3, 4, 5, 6, 7, 8]);
        check(&block2, &mixed2, &[0, 1, 2, 3, 6, 7, 8, 9]);
    }

    // `insert_range` must return the number of NEWLY inserted rows, handle
    // ranges that straddle a fragment boundary (around u32::MAX), support
    // open/inclusive bounds, and report 0 for rows already covered by a
    // full fragment.
    #[test]
    fn test_map_insert_range() {
        let ranges = &[
            (0..10),
            (40..500),
            // Straddles the fragment boundary (upper 32 bits change).
            ((u32::MAX as u64 - 10)..(u32::MAX as u64 + 20)),
        ];

        for range in ranges {
            let mut mask = RowIdTreeMap::default();

            let count = mask.insert_range(range.clone());
            let expected = range.end - range.start;
            assert_eq!(count, expected);

            // Re-inserting the same range adds nothing.
            let count = mask.insert_range(range.clone());
            assert_eq!(count, 0);

            // Shifting by 5 overlaps all but 5 rows.
            let new_range = range.start + 5..range.end + 5;
            let count = mask.insert_range(new_range.clone());
            assert_eq!(count, 5);
        }

        let mut mask = RowIdTreeMap::default();
        // Open start bound: ..10 covers rows 0 through 9.
        let count = mask.insert_range(..10);
        assert_eq!(count, 10);
        assert!(mask.contains(0));

        // Inclusive end bound.
        let count = mask.insert_range(20..=24);
        assert_eq!(count, 5);

        // A full fragment already contains every row, so nothing is new.
        mask.insert_fragment(0);
        let count = mask.insert_range(100..200);
        assert_eq!(count, 0);
    }

    // `remove` must report whether the row was present and actually delete
    // it, both for individually inserted rows and rows from a range.
    #[test]
    fn test_map_remove() {
        let mut mask = RowIdTreeMap::default();

        assert!(!mask.remove(20));

        mask.insert(20);
        assert!(mask.contains(20));
        assert!(mask.remove(20));
        assert!(!mask.contains(20));

        mask.insert_range(10..=20);
        assert!(mask.contains(15));
        assert!(mask.remove(15));
        assert!(!mask.contains(15));

        // We don't test removing from a full fragment, because that would take
        // a lot of memory.
    }

    proptest::proptest! {
        // serialize_into / deserialize_from must round-trip any mix of
        // full fragments and partial bitmaps.
        #[test]
        fn test_map_serialization_roundtrip(
            values in proptest::collection::vec(
                (0..u32::MAX, proptest::option::of(proptest::collection::vec(0..u32::MAX, 0..1000))),
                0..10
            )
        ) {
            let mut mask = RowIdTreeMap::default();
            for (fragment, rows) in values {
                if let Some(rows) = rows {
                    let bitmap = RoaringBitmap::from_iter(rows);
                    mask.insert_bitmap(fragment, bitmap);
                } else {
                    mask.insert_fragment(fragment);
                }
            }

            let mut data = Vec::new();
            mask.serialize_into(&mut data).unwrap();
            let deserialized = RowIdTreeMap::deserialize_from(data.as_slice()).unwrap();
            prop_assert_eq!(mask, deserialized);
        }

        // `&` (intersection) must agree with a brute-force model built from
        // the generated fragments and rows.
        #[test]
        fn test_map_intersect(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowIdTreeMap::default();
            for fragment in left_full_fragments.clone() {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowIdTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }
            right.extend(right_rows.iter().copied());

            // Model: a fragment is fully selected iff full on both sides.
            let mut expected = RowIdTreeMap::default();
            for fragment in &left_full_fragments {
                if right_full_fragments.contains(fragment) {
                    expected.insert_fragment(*fragment);
                }
            }

            // A row survives if the other side has it explicitly or via a
            // full fragment (checked from both directions).
            let left_in_right = left_rows.iter().filter(|row| {
                right_rows.contains(row)
                    || right_full_fragments.contains(&((*row >> 32) as u32))
            });
            expected.extend(left_in_right);
            let right_in_left = right_rows.iter().filter(|row| {
                left_rows.contains(row)
                    || left_full_fragments.contains(&((*row >> 32) as u32))
            });
            expected.extend(right_in_left);

            let actual = left & right;
            prop_assert_eq!(expected, actual);
        }

        // `|` (union) must agree with a brute-force model: all full
        // fragments from either side plus all rows from either side.
        #[test]
        fn test_map_union(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowIdTreeMap::default();
            for fragment in left_full_fragments.clone() {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowIdTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }
            right.extend(right_rows.iter().copied());

            let mut expected = RowIdTreeMap::default();
            for fragment in left_full_fragments {
                expected.insert_fragment(fragment);
            }
            for fragment in right_full_fragments {
                expected.insert_fragment(fragment);
            }

            let combined_rows = left_rows.iter().chain(right_rows.iter());
            expected.extend(combined_rows);

            let actual = left | right;
            // Per-fragment comparison first, so a failure names the
            // offending fragment before the whole-map equality check.
            for actual_key_val in &actual.inner {
                proptest::prop_assert!(expected.inner.contains_key(actual_key_val.0));
                let expected_val = expected.inner.get(actual_key_val.0).unwrap();
                prop_assert_eq!(
                    actual_key_val.1,
                    expected_val,
                    "error on key {}",
                    actual_key_val.0
                );
            }
            prop_assert_eq!(expected, actual);
        }

        // `-=` with a row-only right-hand side must match removing each of
        // those rows one at a time.
        #[test]
        fn test_map_subassign_rows(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowIdTreeMap::default();
            for fragment in left_full_fragments {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowIdTreeMap::default();
            right.extend(right_rows.iter().copied());

            let mut expected = left.clone();
            for row in right_rows {
                expected.remove(row);
            }

            left -= &right;
            prop_assert_eq!(expected, left);
        }

        // `-=` with a full-fragment right-hand side must drop those
        // fragments from the left entirely.
        #[test]
        fn test_map_subassign_frags(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowIdTreeMap::default();
            for fragment in left_full_fragments {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowIdTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }

            let mut expected = left.clone();
            for fragment in right_full_fragments {
                expected.inner.remove(&fragment);
            }

            left -= &right;
            prop_assert_eq!(expected, left);
        }

    }

    // `iter_ids` yields Some(iterator) only when the selection is finite and
    // enumerable: None for the everything-allowed default, None when the
    // allow list contains a full fragment, and None when the block list
    // blocks everything the allow list would yield from a full fragment.
    #[test]
    fn test_iter_ids() {
        let mut mask = RowIdMask::default();
        assert!(mask.iter_ids().is_none());

        // Test with just an allow list
        let mut allow_list = RowIdTreeMap::default();
        allow_list.extend([1, 5, 10].iter().copied());
        mask.allow_list = Some(allow_list);

        let ids: Vec<_> = mask.iter_ids().unwrap().collect();
        assert_eq!(
            ids,
            vec![
                RowAddress::new_from_parts(0, 1),
                RowAddress::new_from_parts(0, 5),
                RowAddress::new_from_parts(0, 10)
            ]
        );

        // Test with both allow list and block list
        let mut block_list = RowIdTreeMap::default();
        block_list.extend([5].iter().copied());
        mask.block_list = Some(block_list);

        let ids: Vec<_> = mask.iter_ids().unwrap().collect();
        assert_eq!(
            ids,
            vec![
                RowAddress::new_from_parts(0, 1),
                RowAddress::new_from_parts(0, 10)
            ]
        );

        // Test with full fragment in block list
        let mut block_list = RowIdTreeMap::default();
        block_list.insert_fragment(0);
        mask.block_list = Some(block_list);
        assert!(mask.iter_ids().is_none());

        // Test with full fragment in allow list
        mask.block_list = None;
        let mut allow_list = RowIdTreeMap::default();
        allow_list.insert_fragment(0);
        mask.allow_list = Some(allow_list);
        assert!(mask.iter_ids().is_none());
    }
}