lance_core/utils/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::collections::HashSet;
5use std::io::Write;
6use std::ops::{Range, RangeBounds};
7use std::{collections::BTreeMap, io::Read};
8
9use arrow_array::{Array, BinaryArray, GenericBinaryArray};
10use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer};
11use byteorder::{ReadBytesExt, WriteBytesExt};
12use deepsize::DeepSizeOf;
13use roaring::{MultiOps, RoaringBitmap, RoaringTreemap};
14
15use crate::Result;
16
17use super::address::RowAddress;
18
/// A row id mask to select or deselect particular row ids
///
/// If both the allow_list and the block_list are Some then the only selected
/// row ids are those that are in the allow_list but not in the block_list
/// (the block_list takes precedence)
///
/// If both the allow_list and the block_list are None (the default) then
/// all row ids are selected
#[derive(Clone, Debug, Default, DeepSizeOf)]
pub struct RowIdMask {
    /// If Some then only these row ids are selected
    ///
    /// An allow_list of None means "all row ids are allowed"
    pub allow_list: Option<RowIdTreeMap>,
    /// If Some then these row ids are not selected.
    ///
    /// A block_list of None means "no row ids are blocked"
    pub block_list: Option<RowIdTreeMap>,
}
34
impl RowIdMask {
    /// Create a mask allowing all rows, this is an alias for [default]
    pub fn all_rows() -> Self {
        Self::default()
    }

    /// Create a mask that doesn't allow anything
    pub fn allow_nothing() -> Self {
        Self {
            // An empty allow list selects zero rows
            allow_list: Some(RowIdTreeMap::new()),
            block_list: None,
        }
    }

    /// Create a mask from an allow list
    pub fn from_allowed(allow_list: RowIdTreeMap) -> Self {
        Self {
            allow_list: Some(allow_list),
            block_list: None,
        }
    }

    /// Create a mask from a block list
    pub fn from_block(block_list: RowIdTreeMap) -> Self {
        Self {
            allow_list: None,
            block_list: Some(block_list),
        }
    }

    /// True if the row_id is selected by the mask, false otherwise
    pub fn selected(&self, row_id: u64) -> bool {
        match (&self.allow_list, &self.block_list) {
            // Neither list present: every row is selected
            (None, None) => true,
            (Some(allow_list), None) => allow_list.contains(row_id),
            (None, Some(block_list)) => !block_list.contains(row_id),
            // Both present: the block list takes precedence
            (Some(allow_list), Some(block_list)) => {
                allow_list.contains(row_id) && !block_list.contains(row_id)
            }
        }
    }

    /// Return the indices of the input row ids that were valid
    ///
    /// # Panics
    ///
    /// Panics if both the allow list and the block list are None; callers
    /// are expected to check for that "nothing to filter" case first.
    pub fn selected_indices<'a>(&self, row_ids: impl Iterator<Item = &'a u64> + 'a) -> Vec<u64> {
        let enumerated_ids = row_ids.enumerate();
        match (&self.block_list, &self.allow_list) {
            (Some(block_list), Some(allow_list)) => {
                // Only take rows that are both in the allow list and not in the block list
                enumerated_ids
                    .filter(|(_, row_id)| {
                        !block_list.contains(**row_id) && allow_list.contains(**row_id)
                    })
                    .map(|(idx, _)| idx as u64)
                    .collect()
            }
            (Some(block_list), None) => {
                // Take rows that are not in the block list
                enumerated_ids
                    .filter(|(_, row_id)| !block_list.contains(**row_id))
                    .map(|(idx, _)| idx as u64)
                    .collect()
            }
            (None, Some(allow_list)) => {
                // Take rows that are in the allow list
                enumerated_ids
                    .filter(|(_, row_id)| allow_list.contains(**row_id))
                    .map(|(idx, _)| idx as u64)
                    .collect()
            }
            (None, None) => {
                // We should not encounter this case because callers should
                // check is_empty first.
                panic!("selected_indices called but prefilter has nothing to filter with")
            }
        }
    }

    /// Also block the given ids
    pub fn also_block(self, block_list: RowIdTreeMap) -> Self {
        // Blocking nothing is a no-op; keep the mask unchanged
        if block_list.is_empty() {
            return self;
        }
        if let Some(existing) = self.block_list {
            Self {
                // Union with the existing block list
                block_list: Some(existing | block_list),
                allow_list: self.allow_list,
            }
        } else {
            Self {
                block_list: Some(block_list),
                allow_list: self.allow_list,
            }
        }
    }

    /// Also allow the given ids
    pub fn also_allow(self, allow_list: RowIdTreeMap) -> Self {
        if let Some(existing) = self.allow_list {
            Self {
                block_list: self.block_list,
                allow_list: Some(existing | allow_list),
            }
        } else {
            Self {
                block_list: self.block_list,
                // allow_list = None means "all rows allowed" and so allowing
                //              more rows is meaningless
                allow_list: None,
            }
        }
    }

    /// Convert a mask into an arrow array
    ///
    /// A row id mask is not very arrow-compatible.  We can't make it a batch with
    /// two columns because the block list and allow list will have different lengths.  Also,
    /// there is no Arrow type for compressed bitmaps.
    ///
    /// However, we need to shove it into some kind of Arrow container to pass it along the
    /// datafusion stream.  Perhaps, in the future, we can add row id masks as first class
    /// types in datafusion, and this can be passed along as a mask / selection vector.
    ///
    /// We serialize this as a variable length binary array with two items.  The first item
    /// is the block list and the second item is the allow list.  A missing list is encoded
    /// as a null entry.
    pub fn into_arrow(&self) -> Result<BinaryArray> {
        let block_list_length = self
            .block_list
            .as_ref()
            .map(|bl| bl.serialized_size())
            .unwrap_or(0);
        let allow_list_length = self
            .allow_list
            .as_ref()
            .map(|al| al.serialized_size())
            .unwrap_or(0);
        let lengths = vec![block_list_length, allow_list_length];
        let offsets = OffsetBuffer::from_lengths(lengths);
        // One contiguous buffer: block list bytes followed by allow list bytes
        let mut value_bytes = vec![0; block_list_length + allow_list_length];
        // Validity tracks which of the two lists is actually present
        let mut validity = vec![false, false];
        if let Some(block_list) = &self.block_list {
            validity[0] = true;
            block_list.serialize_into(&mut value_bytes[0..])?;
        }
        if let Some(allow_list) = &self.allow_list {
            validity[1] = true;
            allow_list.serialize_into(&mut value_bytes[block_list_length..])?;
        }
        let values = Buffer::from(value_bytes);
        let nulls = NullBuffer::from(validity);
        Ok(BinaryArray::try_new(offsets, values, Some(nulls))?)
    }

    /// Deserialize a row id mask from Arrow
    ///
    /// This is the inverse of [Self::into_arrow]: entry 0 is the block list
    /// and entry 1 is the allow list, with null meaning "not present".
    pub fn from_arrow(array: &GenericBinaryArray<i32>) -> Result<Self> {
        let block_list = if array.is_null(0) {
            None
        } else {
            Some(RowIdTreeMap::deserialize_from(array.value(0)))
        }
        .transpose()?;

        let allow_list = if array.is_null(1) {
            None
        } else {
            Some(RowIdTreeMap::deserialize_from(array.value(1)))
        }
        .transpose()?;
        Ok(Self {
            block_list,
            allow_list,
        })
    }
}
208
209impl std::ops::Not for RowIdMask {
210    type Output = Self;
211
212    fn not(self) -> Self::Output {
213        Self {
214            block_list: self.allow_list,
215            allow_list: self.block_list,
216        }
217    }
218}
219
220impl std::ops::BitAnd for RowIdMask {
221    type Output = Self;
222
223    fn bitand(self, rhs: Self) -> Self::Output {
224        let block_list = match (self.block_list, rhs.block_list) {
225            (None, None) => None,
226            (Some(lhs), None) => Some(lhs),
227            (None, Some(rhs)) => Some(rhs),
228            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
229        };
230        let allow_list = match (self.allow_list, rhs.allow_list) {
231            (None, None) => None,
232            (Some(lhs), None) => Some(lhs),
233            (None, Some(rhs)) => Some(rhs),
234            (Some(lhs), Some(rhs)) => Some(lhs & rhs),
235        };
236        Self {
237            block_list,
238            allow_list,
239        }
240    }
241}
242
243impl std::ops::BitOr for RowIdMask {
244    type Output = Self;
245
246    fn bitor(self, rhs: Self) -> Self::Output {
247        let block_list = match (self.block_list, rhs.block_list) {
248            (None, None) => None,
249            (Some(lhs), None) => Some(lhs),
250            (None, Some(rhs)) => Some(rhs),
251            (Some(lhs), Some(rhs)) => Some(lhs & rhs),
252        };
253        let allow_list = match (self.allow_list, rhs.allow_list) {
254            (None, None) => None,
255            // Remember that an allow list of None means "all rows" and
256            // so "all rows" | "some rows" is always "all rows"
257            (Some(_), None) => None,
258            (None, Some(_)) => None,
259            (Some(lhs), Some(rhs)) => Some(lhs | rhs),
260        };
261        Self {
262            block_list,
263            allow_list,
264        }
265    }
266}
267
/// A collection of row ids.
///
/// These row ids may either be stable-style (where they can be an incrementing
/// u64 sequence) or address style, where they are a fragment id and a row offset.
/// When address style, this supports setting entire fragments as selected,
/// without needing to enumerate all the ids in the fragment.
///
/// This is similar to a [RoaringTreemap] but it is optimized for the case where
/// entire fragments are selected or deselected.
#[derive(Clone, Debug, Default, PartialEq, DeepSizeOf)]
pub struct RowIdTreeMap {
    /// The contents of the set. If there is a pair (k, Full) then the entire
    /// fragment k is selected. If there is a pair (k, Partial(v)) then the
    /// fragment k has the selected rows in v.
    ///
    /// Keys are the upper 32 bits of a u64 row id; the bitmaps cover the
    /// lower 32 bits.
    inner: BTreeMap<u32, RowIdSelection>,
}
284
/// Per-fragment selection state within a [RowIdTreeMap].
#[derive(Clone, Debug, PartialEq)]
enum RowIdSelection {
    /// Every row in the fragment is selected.
    Full,
    /// Only the rows whose low 32 bits appear in the bitmap are selected.
    Partial(RoaringBitmap),
}
290
291impl DeepSizeOf for RowIdSelection {
292    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
293        match self {
294            Self::Full => 0,
295            Self::Partial(bitmap) => bitmap.serialized_size(),
296        }
297    }
298}
299
300impl RowIdSelection {
301    fn union_all(selections: &[&Self]) -> Self {
302        let mut is_full = false;
303
304        let res = Self::Partial(
305            selections
306                .iter()
307                .filter_map(|selection| match selection {
308                    Self::Full => {
309                        is_full = true;
310                        None
311                    }
312                    Self::Partial(bitmap) => Some(bitmap),
313                })
314                .union(),
315        );
316
317        if is_full {
318            Self::Full
319        } else {
320            res
321        }
322    }
323}
324
impl RowIdTreeMap {
    /// Create an empty set
    pub fn new() -> Self {
        Self::default()
    }

    /// True if the map holds no fragment entries at all
    ///
    /// NOTE(review): an entry holding an empty partial bitmap would make this
    /// return false even though no rows are selected; most mutating ops drop
    /// such entries, but `insert_range` with an empty range can create one —
    /// confirm whether callers rely on the distinction.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// The number of rows in the map
    ///
    /// If there are any "full fragment" items then this is unknown and None is returned
    pub fn len(&self) -> Option<u64> {
        self.inner
            .values()
            .map(|row_id_selection| match row_id_selection {
                RowIdSelection::Full => None,
                RowIdSelection::Partial(indices) => Some(indices.len()),
            })
            // try_fold short-circuits to None as soon as a Full fragment is seen
            .try_fold(0_u64, |acc, next| next.map(|next| next + acc))
    }

    /// An iterator of row ids (yielded as [RowAddress] values)
    ///
    /// If there are any "full fragment" items then this can't be calculated and None
    /// is returned
    pub fn row_ids(&self) -> Option<impl Iterator<Item = RowAddress> + '_> {
        let inner_iters = self
            .inner
            .iter()
            .filter_map(|(frag_id, row_id_selection)| match row_id_selection {
                // Full fragments are dropped here; detected by the length check below
                RowIdSelection::Full => None,
                RowIdSelection::Partial(bitmap) => Some(
                    bitmap
                        .iter()
                        .map(|row_offset| RowAddress::new_from_parts(*frag_id, row_offset)),
                ),
            })
            .collect::<Vec<_>>();
        if inner_iters.len() != self.inner.len() {
            None
        } else {
            Some(inner_iters.into_iter().flatten())
        }
    }

    /// Insert a single value into the set
    ///
    /// Returns true if the value was not already in the set.
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowIdTreeMap;
    ///
    /// let mut set = RowIdTreeMap::new();
    /// assert_eq!(set.insert(10), true);
    /// assert_eq!(set.insert(10), false);
    /// assert_eq!(set.contains(10), true);
    /// ```
    pub fn insert(&mut self, value: u64) -> bool {
        // Upper 32 bits are the fragment id, lower 32 bits the row offset
        let fragment = (value >> 32) as u32;
        let row_addr = value as u32;
        match self.inner.get_mut(&fragment) {
            None => {
                let mut set = RoaringBitmap::new();
                set.insert(row_addr);
                self.inner.insert(fragment, RowIdSelection::Partial(set));
                true
            }
            // Fragment fully selected: the value was already present
            Some(RowIdSelection::Full) => false,
            Some(RowIdSelection::Partial(set)) => set.insert(row_addr),
        }
    }

    /// Insert a range of values into the set
    ///
    /// Returns the number of values that were not already present.
    pub fn insert_range<R: RangeBounds<u64>>(&mut self, range: R) -> u64 {
        // Separate the start and end into high (fragment) and low (offset) bits.
        let (mut start_high, mut start_low) = match range.start_bound() {
            std::ops::Bound::Included(&start) => ((start >> 32) as u32, start as u32),
            std::ops::Bound::Excluded(&start) => {
                // NOTE(review): saturating_add means an excluded start of
                // u64::MAX behaves like an included u64::MAX — confirm this
                // edge case is acceptable.
                let start = start.saturating_add(1);
                ((start >> 32) as u32, start as u32)
            }
            std::ops::Bound::Unbounded => (0, 0),
        };

        // Normalize the end to an inclusive bound
        let (end_high, end_low) = match range.end_bound() {
            std::ops::Bound::Included(&end) => ((end >> 32) as u32, end as u32),
            std::ops::Bound::Excluded(&end) => {
                let end = end.saturating_sub(1);
                ((end >> 32) as u32, end as u32)
            }
            std::ops::Bound::Unbounded => (u32::MAX, u32::MAX),
        };

        let mut count = 0;

        // Walk one fragment at a time, inserting the slice of the range that
        // falls inside that fragment.
        while start_high <= end_high {
            let start = start_low;
            let end = if start_high == end_high {
                end_low
            } else {
                u32::MAX
            };
            let fragment = start_high;
            match self.inner.get_mut(&fragment) {
                None => {
                    let mut set = RoaringBitmap::new();
                    count += set.insert_range(start..=end);
                    self.inner.insert(fragment, RowIdSelection::Partial(set));
                }
                // A fully selected fragment gains nothing from the insert
                Some(RowIdSelection::Full) => {}
                Some(RowIdSelection::Partial(set)) => {
                    count += set.insert_range(start..=end);
                }
            }
            start_high += 1;
            // Subsequent fragments are covered from their first row
            start_low = 0;
        }

        count
    }

    /// Add a bitmap for a single fragment
    ///
    /// Any existing selection for the fragment is replaced.
    pub fn insert_bitmap(&mut self, fragment: u32, bitmap: RoaringBitmap) {
        self.inner.insert(fragment, RowIdSelection::Partial(bitmap));
    }

    /// Add a whole fragment to the set
    ///
    /// Any existing selection for the fragment is replaced.
    pub fn insert_fragment(&mut self, fragment_id: u32) {
        self.inner.insert(fragment_id, RowIdSelection::Full);
    }

    /// The partial bitmap for a fragment, if any
    ///
    /// Returns None both when the fragment is absent and when it is fully
    /// selected (a full selection has no bitmap to return).
    pub fn get_fragment_bitmap(&self, fragment_id: u32) -> Option<&RoaringBitmap> {
        match self.inner.get(&fragment_id) {
            None => None,
            Some(RowIdSelection::Full) => None,
            Some(RowIdSelection::Partial(set)) => Some(set),
        }
    }

    /// Returns whether the set contains the given value
    pub fn contains(&self, value: u64) -> bool {
        let upper = (value >> 32) as u32;
        let lower = value as u32;
        match self.inner.get(&upper) {
            None => false,
            Some(RowIdSelection::Full) => true,
            Some(RowIdSelection::Partial(fragment_set)) => fragment_set.contains(lower),
        }
    }

    /// Remove a single value from the set, returning true if it was present
    pub fn remove(&mut self, value: u64) -> bool {
        let upper = (value >> 32) as u32;
        let lower = value as u32;
        match self.inner.get_mut(&upper) {
            None => false,
            Some(RowIdSelection::Full) => {
                // Materialize the full fragment as a bitmap minus the value.
                // NOTE(review): RoaringBitmap::full() covers the whole u32
                // domain and is expensive to build — confirm this path is rare.
                let mut set = RoaringBitmap::full();
                set.remove(lower);
                self.inner.insert(upper, RowIdSelection::Partial(set));
                true
            }
            Some(RowIdSelection::Partial(lower_set)) => {
                let removed = lower_set.remove(lower);
                // Drop the entry entirely once its bitmap is empty
                if lower_set.is_empty() {
                    self.inner.remove(&upper);
                }
                removed
            }
        }
    }

    /// Keep only the entries belonging to the given fragments
    pub fn retain_fragments(&mut self, frag_ids: impl IntoIterator<Item = u32>) {
        let frag_id_set = frag_ids.into_iter().collect::<HashSet<_>>();
        self.inner
            .retain(|frag_id, _| frag_id_set.contains(frag_id));
    }

    /// Compute the serialized size of the set.
    ///
    /// Matches the byte count produced by [Self::serialize_into].
    pub fn serialized_size(&self) -> usize {
        // Starts at 4 because of the u32 num_entries
        let mut size = 4;
        for set in self.inner.values() {
            // Each entry is 8 bytes for the fragment id and the bitmap size
            size += 8;
            if let RowIdSelection::Partial(set) = set {
                size += set.serialized_size();
            }
        }
        size
    }

    /// Serialize the set into the given buffer
    ///
    /// The serialization format is not stable.
    ///
    /// The serialization format is:
    /// * u32: num_entries
    ///
    /// for each entry:
    ///   * u32: fragment_id
    ///   * u32: bitmap size
    ///   * \[u8\]: bitmap
    ///
    /// If bitmap size is zero then the entire fragment is selected.
    pub fn serialize_into<W: Write>(&self, mut writer: W) -> Result<()> {
        writer.write_u32::<byteorder::LittleEndian>(self.inner.len() as u32)?;
        for (fragment, set) in &self.inner {
            writer.write_u32::<byteorder::LittleEndian>(*fragment)?;
            if let RowIdSelection::Partial(set) = set {
                writer.write_u32::<byteorder::LittleEndian>(set.serialized_size() as u32)?;
                set.serialize_into(&mut writer)?;
            } else {
                // Full fragments are encoded as a zero-length bitmap
                writer.write_u32::<byteorder::LittleEndian>(0)?;
            }
        }
        Ok(())
    }

    /// Deserialize the set from the given buffer
    ///
    /// This is the inverse of [Self::serialize_into].
    pub fn deserialize_from<R: Read>(mut reader: R) -> Result<Self> {
        let num_entries = reader.read_u32::<byteorder::LittleEndian>()?;
        let mut inner = BTreeMap::new();
        for _ in 0..num_entries {
            let fragment = reader.read_u32::<byteorder::LittleEndian>()?;
            let bitmap_size = reader.read_u32::<byteorder::LittleEndian>()?;
            if bitmap_size == 0 {
                // Zero-length bitmap marks a fully selected fragment
                inner.insert(fragment, RowIdSelection::Full);
            } else {
                let mut buffer = vec![0; bitmap_size as usize];
                reader.read_exact(&mut buffer)?;
                let set = RoaringBitmap::deserialize_from(&buffer[..])?;
                inner.insert(fragment, RowIdSelection::Partial(set));
            }
        }
        Ok(Self { inner })
    }

    /// Union many maps into a single map
    pub fn union_all(maps: &[&Self]) -> Self {
        let mut new_map = BTreeMap::new();

        // Group the per-fragment selections across all input maps
        for map in maps {
            for (fragment, selection) in &map.inner {
                new_map
                    .entry(fragment)
                    // I hate this allocation, but I can't think of a better way
                    .or_insert_with(|| Vec::with_capacity(maps.len()))
                    .push(selection);
            }
        }

        // Union each fragment's grouped selections
        let new_map = new_map
            .into_iter()
            .map(|(&fragment, selections)| (fragment, RowIdSelection::union_all(&selections)))
            .collect();

        Self { inner: new_map }
    }

    /// Apply a mask to the row ids
    ///
    /// If there is an allow list then this will intersect the set with the allow list
    /// If there is a block list then this will subtract the block list from the set
    pub fn mask(&mut self, mask: &RowIdMask) {
        if let Some(allow_list) = &mask.allow_list {
            *self &= allow_list;
        }
        if let Some(block_list) = &mask.block_list {
            *self -= block_list;
        }
    }

    /// Convert the set into an iterator of row ids
    ///
    /// # Safety
    ///
    /// This is unsafe because if any of the inner RowIdSelection elements
    /// is not a Partial then the iterator will panic because we don't know
    /// the size of the bitmap.
    pub unsafe fn into_id_iter(self) -> impl Iterator<Item = u64> {
        self.inner
            .into_iter()
            .flat_map(|(fragment, selection)| match selection {
                RowIdSelection::Full => panic!("Size of full fragment is unknown"),
                RowIdSelection::Partial(bitmap) => bitmap.into_iter().map(move |val| {
                    // Reassemble the u64 row id from fragment and offset
                    let fragment = fragment as u64;
                    let row_offset = val as u64;
                    (fragment << 32) | row_offset
                }),
            })
    }
}
618
619impl std::ops::BitOr<Self> for RowIdTreeMap {
620    type Output = Self;
621
622    fn bitor(mut self, rhs: Self) -> Self::Output {
623        self |= rhs;
624        self
625    }
626}
627
628impl std::ops::BitOrAssign<Self> for RowIdTreeMap {
629    fn bitor_assign(&mut self, rhs: Self) {
630        for (fragment, rhs_set) in &rhs.inner {
631            let lhs_set = self.inner.get_mut(fragment);
632            if let Some(lhs_set) = lhs_set {
633                match lhs_set {
634                    RowIdSelection::Full => {
635                        // If the fragment is already selected then there is nothing to do
636                    }
637                    RowIdSelection::Partial(lhs_bitmap) => match rhs_set {
638                        RowIdSelection::Full => {
639                            *lhs_set = RowIdSelection::Full;
640                        }
641                        RowIdSelection::Partial(rhs_set) => {
642                            *lhs_bitmap |= rhs_set;
643                        }
644                    },
645                }
646            } else {
647                self.inner.insert(*fragment, rhs_set.clone());
648            }
649        }
650    }
651}
652
653impl std::ops::BitAnd<Self> for RowIdTreeMap {
654    type Output = Self;
655
656    fn bitand(mut self, rhs: Self) -> Self::Output {
657        self &= &rhs;
658        self
659    }
660}
661
662impl std::ops::BitAndAssign<&Self> for RowIdTreeMap {
663    fn bitand_assign(&mut self, rhs: &Self) {
664        // Remove fragment that aren't on the RHS
665        self.inner
666            .retain(|fragment, _| rhs.inner.contains_key(fragment));
667
668        // For fragments that are on the RHS, intersect the bitmaps
669        for (fragment, mut lhs_set) in &mut self.inner {
670            match (&mut lhs_set, rhs.inner.get(fragment)) {
671                (_, None) => {} // Already handled by retain
672                (_, Some(RowIdSelection::Full)) => {
673                    // Everything selected on RHS, so can leave LHS untouched.
674                }
675                (RowIdSelection::Partial(lhs_set), Some(RowIdSelection::Partial(rhs_set))) => {
676                    *lhs_set &= rhs_set;
677                }
678                (RowIdSelection::Full, Some(RowIdSelection::Partial(rhs_set))) => {
679                    *lhs_set = RowIdSelection::Partial(rhs_set.clone());
680                }
681            }
682        }
683        // Some bitmaps might now be empty. If they are, we should remove them.
684        self.inner.retain(|_, set| match set {
685            RowIdSelection::Partial(set) => !set.is_empty(),
686            RowIdSelection::Full => true,
687        });
688    }
689}
690
691impl std::ops::SubAssign<&Self> for RowIdTreeMap {
692    fn sub_assign(&mut self, rhs: &Self) {
693        for (fragment, rhs_set) in &rhs.inner {
694            match self.inner.get_mut(fragment) {
695                None => {}
696                Some(RowIdSelection::Full) => {
697                    // If the fragment is already selected then there is nothing to do
698                    match rhs_set {
699                        RowIdSelection::Full => {
700                            self.inner.remove(fragment);
701                        }
702                        RowIdSelection::Partial(rhs_set) => {
703                            // This generally won't be hit.
704                            let mut set = RoaringBitmap::full();
705                            set -= rhs_set;
706                            self.inner.insert(*fragment, RowIdSelection::Partial(set));
707                        }
708                    }
709                }
710                Some(RowIdSelection::Partial(lhs_set)) => match rhs_set {
711                    RowIdSelection::Full => {
712                        self.inner.remove(fragment);
713                    }
714                    RowIdSelection::Partial(rhs_set) => {
715                        *lhs_set -= rhs_set;
716                        if lhs_set.is_empty() {
717                            self.inner.remove(fragment);
718                        }
719                    }
720                },
721            }
722        }
723    }
724}
725
726impl FromIterator<u64> for RowIdTreeMap {
727    fn from_iter<T: IntoIterator<Item = u64>>(iter: T) -> Self {
728        let mut inner = BTreeMap::new();
729        for row_id in iter {
730            let upper = (row_id >> 32) as u32;
731            let lower = row_id as u32;
732            match inner.get_mut(&upper) {
733                None => {
734                    let mut set = RoaringBitmap::new();
735                    set.insert(lower);
736                    inner.insert(upper, RowIdSelection::Partial(set));
737                }
738                Some(RowIdSelection::Full) => {
739                    // If the fragment is already selected then there is nothing to do
740                }
741                Some(RowIdSelection::Partial(set)) => {
742                    set.insert(lower);
743                }
744            }
745        }
746        Self { inner }
747    }
748}
749
750impl<'a> FromIterator<&'a u64> for RowIdTreeMap {
751    fn from_iter<T: IntoIterator<Item = &'a u64>>(iter: T) -> Self {
752        Self::from_iter(iter.into_iter().copied())
753    }
754}
755
756impl From<Range<u64>> for RowIdTreeMap {
757    fn from(range: Range<u64>) -> Self {
758        let mut map = Self::default();
759        map.insert_range(range);
760        map
761    }
762}
763
764impl From<RoaringTreemap> for RowIdTreeMap {
765    fn from(roaring: RoaringTreemap) -> Self {
766        let mut inner = BTreeMap::new();
767        for (fragment, set) in roaring.bitmaps() {
768            inner.insert(fragment, RowIdSelection::Partial(set.clone()));
769        }
770        Self { inner }
771    }
772}
773
774impl Extend<u64> for RowIdTreeMap {
775    fn extend<T: IntoIterator<Item = u64>>(&mut self, iter: T) {
776        for row_id in iter {
777            let upper = (row_id >> 32) as u32;
778            let lower = row_id as u32;
779            match self.inner.get_mut(&upper) {
780                None => {
781                    let mut set = RoaringBitmap::new();
782                    set.insert(lower);
783                    self.inner.insert(upper, RowIdSelection::Partial(set));
784                }
785                Some(RowIdSelection::Full) => {
786                    // If the fragment is already selected then there is nothing to do
787                }
788                Some(RowIdSelection::Partial(set)) => {
789                    set.insert(lower);
790                }
791            }
792        }
793    }
794}
795
796impl<'a> Extend<&'a u64> for RowIdTreeMap {
797    fn extend<T: IntoIterator<Item = &'a u64>>(&mut self, iter: T) {
798        self.extend(iter.into_iter().copied())
799    }
800}
801
802// Extending with RowIdTreeMap is basically a cumulative set union
803impl Extend<Self> for RowIdTreeMap {
804    fn extend<T: IntoIterator<Item = Self>>(&mut self, iter: T) {
805        for other in iter {
806            for (fragment, set) in other.inner {
807                match self.inner.get_mut(&fragment) {
808                    None => {
809                        self.inner.insert(fragment, set);
810                    }
811                    Some(RowIdSelection::Full) => {
812                        // If the fragment is already selected then there is nothing to do
813                    }
814                    Some(RowIdSelection::Partial(lhs_set)) => match set {
815                        RowIdSelection::Full => {
816                            self.inner.insert(fragment, RowIdSelection::Full);
817                        }
818                        RowIdSelection::Partial(rhs_set) => {
819                            *lhs_set |= rhs_set;
820                        }
821                    },
822                }
823            }
824        }
825    }
826}
827
828#[cfg(test)]
829mod tests {
830    use super::*;
831    use proptest::prop_assert_eq;
832
    #[test]
    fn test_ops() {
        // The default mask selects everything
        let mask = RowIdMask::default();
        assert!(mask.selected(1));
        assert!(mask.selected(5));
        // Blocking some ids deselects exactly those ids
        let block_list = mask.also_block(RowIdTreeMap::from_iter(&[0, 5, 15]));
        assert!(block_list.selected(1));
        assert!(!block_list.selected(5));
        // An allow list selects only its members
        let allow_list = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[0, 2, 5]));
        assert!(!allow_list.selected(1));
        assert!(allow_list.selected(5));
        // AND: a row must be allowed and not blocked
        let combined = block_list & allow_list;
        assert!(combined.selected(2));
        assert!(!combined.selected(0));
        assert!(!combined.selected(5));
        // OR with another allow list unions the allow lists
        let other = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[3]));
        let combined = combined | other;
        assert!(combined.selected(2));
        assert!(combined.selected(3));
        assert!(!combined.selected(0));
        assert!(!combined.selected(5));

        // OR of a pure block mask with a pure allow mask keeps the block list
        // and drops the allow restriction
        let block_list = RowIdMask::from_block(RowIdTreeMap::from_iter(&[0]));
        let allow_list = RowIdMask::from_allowed(RowIdTreeMap::from_iter(&[3]));
        let combined = block_list | allow_list;
        assert!(combined.selected(1));
    }
860
861    #[test]
862    fn test_map_insert_range() {
863        let ranges = &[
864            (0..10),
865            (40..500),
866            ((u32::MAX as u64 - 10)..(u32::MAX as u64 + 20)),
867        ];
868
869        for range in ranges {
870            let mut mask = RowIdTreeMap::default();
871
872            let count = mask.insert_range(range.clone());
873            let expected = range.end - range.start;
874            assert_eq!(count, expected);
875
876            let count = mask.insert_range(range.clone());
877            assert_eq!(count, 0);
878
879            let new_range = range.start + 5..range.end + 5;
880            let count = mask.insert_range(new_range.clone());
881            assert_eq!(count, 5);
882        }
883
884        let mut mask = RowIdTreeMap::default();
885        let count = mask.insert_range(..10);
886        assert_eq!(count, 10);
887        assert!(mask.contains(0));
888
889        let count = mask.insert_range(20..=24);
890        assert_eq!(count, 5);
891
892        mask.insert_fragment(0);
893        let count = mask.insert_range(100..200);
894        assert_eq!(count, 0);
895    }
896
897    #[test]
898    fn test_map_remove() {
899        let mut mask = RowIdTreeMap::default();
900
901        assert!(!mask.remove(20));
902
903        mask.insert(20);
904        assert!(mask.contains(20));
905        assert!(mask.remove(20));
906        assert!(!mask.contains(20));
907
908        mask.insert_range(10..=20);
909        assert!(mask.contains(15));
910        assert!(mask.remove(15));
911        assert!(!mask.contains(15));
912
913        // We don't test removing from a full fragment, because that would take
914        // a lot of memory.
915    }
916
917    proptest::proptest! {
918        #[test]
919        fn test_map_serialization_roundtrip(
920            values in proptest::collection::vec(
921                (0..u32::MAX, proptest::option::of(proptest::collection::vec(0..u32::MAX, 0..1000))),
922                0..10
923            )
924        ) {
925            let mut mask = RowIdTreeMap::default();
926            for (fragment, rows) in values {
927                if let Some(rows) = rows {
928                    let bitmap = RoaringBitmap::from_iter(rows);
929                    mask.insert_bitmap(fragment, bitmap);
930                } else {
931                    mask.insert_fragment(fragment);
932                }
933            }
934
935            let mut data = Vec::new();
936            mask.serialize_into(&mut data).unwrap();
937            let deserialized = RowIdTreeMap::deserialize_from(data.as_slice()).unwrap();
938            prop_assert_eq!(mask, deserialized);
939        }
940
941        #[test]
942        fn test_map_intersect(
943            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
944            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
945            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
946            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
947        ) {
948            let mut left = RowIdTreeMap::default();
949            for fragment in left_full_fragments.clone() {
950                left.insert_fragment(fragment);
951            }
952            left.extend(left_rows.iter().copied());
953
954            let mut right = RowIdTreeMap::default();
955            for fragment in right_full_fragments.clone() {
956                right.insert_fragment(fragment);
957            }
958            right.extend(right_rows.iter().copied());
959
960            let mut expected = RowIdTreeMap::default();
961            for fragment in &left_full_fragments {
962                if right_full_fragments.contains(fragment) {
963                    expected.insert_fragment(*fragment);
964                }
965            }
966
967            let left_in_right = left_rows.iter().filter(|row| {
968                right_rows.contains(row)
969                    || right_full_fragments.contains(&((*row >> 32) as u32))
970            });
971            expected.extend(left_in_right);
972            let right_in_left = right_rows.iter().filter(|row| {
973                left_rows.contains(row)
974                    || left_full_fragments.contains(&((*row >> 32) as u32))
975            });
976            expected.extend(right_in_left);
977
978            let actual = left & right;
979            prop_assert_eq!(expected, actual);
980        }
981
982        #[test]
983        fn test_map_union(
984            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
985            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
986            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
987            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
988        ) {
989            let mut left = RowIdTreeMap::default();
990            for fragment in left_full_fragments.clone() {
991                left.insert_fragment(fragment);
992            }
993            left.extend(left_rows.iter().copied());
994
995            let mut right = RowIdTreeMap::default();
996            for fragment in right_full_fragments.clone() {
997                right.insert_fragment(fragment);
998            }
999            right.extend(right_rows.iter().copied());
1000
1001            let mut expected = RowIdTreeMap::default();
1002            for fragment in left_full_fragments {
1003                expected.insert_fragment(fragment);
1004            }
1005            for fragment in right_full_fragments {
1006                expected.insert_fragment(fragment);
1007            }
1008
1009            let combined_rows = left_rows.iter().chain(right_rows.iter());
1010            expected.extend(combined_rows);
1011
1012            let actual = left | right;
1013            for actual_key_val in &actual.inner {
1014                proptest::prop_assert!(expected.inner.contains_key(actual_key_val.0));
1015                let expected_val = expected.inner.get(actual_key_val.0).unwrap();
1016                prop_assert_eq!(
1017                    actual_key_val.1,
1018                    expected_val,
1019                    "error on key {}",
1020                    actual_key_val.0
1021                );
1022            }
1023            prop_assert_eq!(expected, actual);
1024        }
1025
1026        #[test]
1027        fn test_map_subassign_rows(
1028            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1029            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1030            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1031        ) {
1032            let mut left = RowIdTreeMap::default();
1033            for fragment in left_full_fragments {
1034                left.insert_fragment(fragment);
1035            }
1036            left.extend(left_rows.iter().copied());
1037
1038            let mut right = RowIdTreeMap::default();
1039            right.extend(right_rows.iter().copied());
1040
1041            let mut expected = left.clone();
1042            for row in right_rows {
1043                expected.remove(row);
1044            }
1045
1046            left -= &right;
1047            prop_assert_eq!(expected, left);
1048        }
1049
1050        #[test]
1051        fn test_map_subassign_frags(
1052            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1053            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1054            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1055        ) {
1056            let mut left = RowIdTreeMap::default();
1057            for fragment in left_full_fragments {
1058                left.insert_fragment(fragment);
1059            }
1060            left.extend(left_rows.iter().copied());
1061
1062            let mut right = RowIdTreeMap::default();
1063            for fragment in right_full_fragments.clone() {
1064                right.insert_fragment(fragment);
1065            }
1066
1067            let mut expected = left.clone();
1068            for fragment in right_full_fragments {
1069                expected.inner.remove(&fragment);
1070            }
1071
1072            left -= &right;
1073            prop_assert_eq!(expected, left);
1074        }
1075    }
1076}