lance_core/utils/mask.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

use std::collections::HashSet;
use std::io::Write;
use std::ops::{Range, RangeBounds, RangeInclusive};
use std::{collections::BTreeMap, io::Read};

use arrow_array::{Array, BinaryArray, GenericBinaryArray};
use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer};
use byteorder::{ReadBytesExt, WriteBytesExt};
use deepsize::DeepSizeOf;
use itertools::Itertools;
use roaring::{MultiOps, RoaringBitmap, RoaringTreemap};

use crate::error::ToSnafuLocation;
use crate::{Error, Result};

use super::address::RowAddress;

mod nullable;

pub use nullable::{NullableRowAddrMask, NullableRowAddrSet};

/// A mask that selects or deselects rows based on an allow-list or block-list.
#[derive(Clone, Debug, DeepSizeOf, PartialEq)]
pub enum RowAddrMask {
    AllowList(RowAddrTreeMap),
    BlockList(RowAddrTreeMap),
}

impl Default for RowAddrMask {
    fn default() -> Self {
        // Empty block list means all rows are allowed
        Self::BlockList(RowAddrTreeMap::new())
    }
}

impl RowAddrMask {
    /// Create a mask allowing all rows; this is an alias for `Self::default`
    pub fn all_rows() -> Self {
        Self::default()
    }

    /// Create a mask that doesn't allow anything
    pub fn allow_nothing() -> Self {
        Self::AllowList(RowAddrTreeMap::new())
    }

    /// Create a mask from an allow list
    pub fn from_allowed(allow_list: RowAddrTreeMap) -> Self {
        Self::AllowList(allow_list)
    }

    /// Create a mask from a block list
    pub fn from_block(block_list: RowAddrTreeMap) -> Self {
        Self::BlockList(block_list)
    }

    pub fn block_list(&self) -> Option<&RowAddrTreeMap> {
        match self {
            Self::BlockList(block_list) => Some(block_list),
            _ => None,
        }
    }

    pub fn allow_list(&self) -> Option<&RowAddrTreeMap> {
        match self {
            Self::AllowList(allow_list) => Some(allow_list),
            _ => None,
        }
    }

    /// True if the row_id is selected by the mask, false otherwise
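    ///
    /// A small doctest sketch of the allow/block semantics (illustrative values):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
    ///
    /// let allow = RowAddrMask::from_allowed(RowAddrTreeMap::from_iter([10u64, 20]));
    /// assert!(allow.selected(10));
    /// assert!(!allow.selected(15));
    ///
    /// let block = RowAddrMask::from_block(RowAddrTreeMap::from_iter([10u64, 20]));
    /// assert!(!block.selected(10));
    /// assert!(block.selected(15));
    /// ```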
    pub fn selected(&self, row_id: u64) -> bool {
        match self {
            Self::AllowList(allow_list) => allow_list.contains(row_id),
            Self::BlockList(block_list) => !block_list.contains(row_id),
        }
    }

    /// Return the indices of the input row ids that are selected by the mask
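    ///
    /// Doctest sketch (mirrors `test_selected_indices` below):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
    ///
    /// let mask = RowAddrMask::from_allowed(RowAddrTreeMap::from_iter([10u64, 20, 40]));
    /// assert_eq!(mask.selected_indices([25u64, 20, 14, 10].iter()), vec![1u64, 3]);
    /// ```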
    pub fn selected_indices<'a>(&self, row_ids: impl Iterator<Item = &'a u64> + 'a) -> Vec<u64> {
        row_ids
            .enumerate()
            .filter_map(|(idx, row_id)| {
                if self.selected(*row_id) {
                    Some(idx as u64)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Also block the given addrs
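    ///
    /// Doctest sketch (illustrative values):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
    ///
    /// let mask = RowAddrMask::from_allowed(RowAddrTreeMap::from_iter([10u64, 20, 30]));
    /// let mask = mask.also_block(RowAddrTreeMap::from_iter([20u64, 40]));
    /// assert!(mask.selected(10));
    /// assert!(!mask.selected(20));
    /// ```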
    pub fn also_block(self, block_list: RowAddrTreeMap) -> Self {
        match self {
            Self::AllowList(allow_list) => Self::AllowList(allow_list - block_list),
            Self::BlockList(existing) => Self::BlockList(existing | block_list),
        }
    }

    /// Also allow the given addrs
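    ///
    /// Doctest sketch (illustrative values):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
    ///
    /// let mask = RowAddrMask::from_block(RowAddrTreeMap::from_iter([10u64, 20, 30]));
    /// let mask = mask.also_allow(RowAddrTreeMap::from_iter([20u64, 40]));
    /// assert!(mask.selected(20));
    /// assert!(!mask.selected(10));
    /// ```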
    pub fn also_allow(self, allow_list: RowAddrTreeMap) -> Self {
        match self {
            Self::AllowList(existing) => Self::AllowList(existing | allow_list),
            Self::BlockList(block_list) => Self::BlockList(block_list - allow_list),
        }
    }

    /// Convert a mask into an arrow array
    ///
    /// A row addr mask is not very arrow-compatible.  We can't make it a batch with
    /// two columns because the block list and allow list will have different lengths.  Also,
    /// there is no Arrow type for compressed bitmaps.
    ///
    /// However, we need to shove it into some kind of Arrow container to pass it along the
    /// datafusion stream.  Perhaps, in the future, we can add row addr masks as first class
    /// types in datafusion, and this can be passed along as a mask / selection vector.
    ///
    /// We serialize this as a variable length binary array with two items.  The first item
    /// is the block list and the second item is the allow list.
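    ///
    /// Round-trip doctest sketch (see also `test_roundtrip_arrow` below):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
    ///
    /// let original = RowAddrMask::from_allowed(RowAddrTreeMap::from_iter([1u64, 2, 3]));
    /// let array = original.into_arrow().unwrap();
    /// assert_eq!(RowAddrMask::from_arrow(&array).unwrap(), original);
    /// ```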
    pub fn into_arrow(&self) -> Result<BinaryArray> {
        // NOTE: This serialization format must be stable as it is used in IPC.
        let (block_list, allow_list) = match self {
            Self::AllowList(allow_list) => (None, Some(allow_list)),
            Self::BlockList(block_list) => (Some(block_list), None),
        };

        let block_list_length = block_list
            .as_ref()
            .map(|bl| bl.serialized_size())
            .unwrap_or(0);
        let allow_list_length = allow_list
            .as_ref()
            .map(|al| al.serialized_size())
            .unwrap_or(0);
        let lengths = vec![block_list_length, allow_list_length];
        let offsets = OffsetBuffer::from_lengths(lengths);
        let mut value_bytes = vec![0; block_list_length + allow_list_length];
        let mut validity = vec![false, false];
        if let Some(block_list) = &block_list {
            validity[0] = true;
            block_list.serialize_into(&mut value_bytes[0..])?;
        }
        if let Some(allow_list) = &allow_list {
            validity[1] = true;
            allow_list.serialize_into(&mut value_bytes[block_list_length..])?;
        }
        let values = Buffer::from(value_bytes);
        let nulls = NullBuffer::from(validity);
        Ok(BinaryArray::try_new(offsets, values, Some(nulls))?)
    }

    /// Deserialize a row address mask from Arrow
    pub fn from_arrow(array: &GenericBinaryArray<i32>) -> Result<Self> {
        let block_list = if array.is_null(0) {
            None
        } else {
            Some(RowAddrTreeMap::deserialize_from(array.value(0)))
        }
        .transpose()?;

        let allow_list = if array.is_null(1) {
            None
        } else {
            Some(RowAddrTreeMap::deserialize_from(array.value(1)))
        }
        .transpose()?;

        let res = match (block_list, allow_list) {
            (Some(bl), None) => Self::BlockList(bl),
            (None, Some(al)) => Self::AllowList(al),
            (Some(block), Some(allow)) => Self::AllowList(allow).also_block(block),
            (None, None) => Self::all_rows(),
        };
        Ok(res)
    }

    /// Return the maximum number of row addresses that could be selected by this mask
    ///
    /// Will be None if this is a BlockList (unbounded), or if the allow list
    /// contains any "full fragment" selections (whose sizes are unknown)
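    ///
    /// For example (doctest sketch):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
    ///
    /// let allow = RowAddrMask::from_allowed(RowAddrTreeMap::from_iter([1u64, 2, 3]));
    /// assert_eq!(allow.max_len(), Some(3));
    /// assert_eq!(RowAddrMask::all_rows().max_len(), None);
    /// ```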
    pub fn max_len(&self) -> Option<u64> {
        match self {
            Self::AllowList(selection) => selection.len(),
            Self::BlockList(_) => None,
        }
    }

    /// Iterate over the row addresses that are selected by the mask
    ///
    /// This is only possible if this is an AllowList and the map doesn't contain
    /// any "full fragment" blocks.
    pub fn iter_addrs(&self) -> Option<Box<dyn Iterator<Item = RowAddress> + '_>> {
        match self {
            Self::AllowList(allow_list) => {
                if let Some(allow_iter) = allow_list.row_addrs() {
                    Some(Box::new(allow_iter))
                } else {
                    None
                }
            }
            Self::BlockList(_) => None, // Can't iterate over block list
        }
    }
}

impl std::ops::Not for RowAddrMask {
    type Output = Self;

    fn not(self) -> Self::Output {
        match self {
            Self::AllowList(allow_list) => Self::BlockList(allow_list),
            Self::BlockList(block_list) => Self::AllowList(block_list),
        }
    }
}

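/// Intersecting two masks selects only the rows selected by both sides. A small
/// doctest sketch (expected values mirror `test_logical_and` below):
///
/// ```rust
/// use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
///
/// let allow = RowAddrMask::from_allowed(RowAddrTreeMap::from_iter([0u64, 1]));
/// let block = RowAddrMask::from_block(RowAddrTreeMap::from_iter([1u64, 2]));
/// let combined = allow & block;
/// assert!(combined.selected(0));
/// assert!(!combined.selected(1));
/// assert!(!combined.selected(2));
/// ```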
impl std::ops::BitAnd for RowAddrMask {
    type Output = Self;

    fn bitand(self, rhs: Self) -> Self::Output {
        match (self, rhs) {
            (Self::AllowList(a), Self::AllowList(b)) => Self::AllowList(a & b),
            (Self::AllowList(allow), Self::BlockList(block))
            | (Self::BlockList(block), Self::AllowList(allow)) => Self::AllowList(allow - block),
            (Self::BlockList(a), Self::BlockList(b)) => Self::BlockList(a | b),
        }
    }
}

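/// Unioning two masks selects the rows selected by either side; a row blocked on
/// one side but allowed on the other ends up selected. Doctest sketch:
///
/// ```rust
/// use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
///
/// let block = RowAddrMask::from_block(RowAddrTreeMap::from_iter([0u64]));
/// let allow = RowAddrMask::from_allowed(RowAddrTreeMap::from_iter([0u64, 3]));
/// let combined = block | allow;
/// assert!(combined.selected(0));
/// assert!(combined.selected(1));
/// ```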
impl std::ops::BitOr for RowAddrMask {
    type Output = Self;

    fn bitor(self, rhs: Self) -> Self::Output {
        match (self, rhs) {
            (Self::AllowList(a), Self::AllowList(b)) => Self::AllowList(a | b),
            (Self::AllowList(allow), Self::BlockList(block))
            | (Self::BlockList(block), Self::AllowList(allow)) => Self::BlockList(block - allow),
            (Self::BlockList(a), Self::BlockList(b)) => Self::BlockList(a & b),
        }
    }
}

/// Common operations over a set of rows (either row ids or row addresses).
///
/// The concrete representation can be address-based (`RowAddrTreeMap`) or
/// id-based (for example a future `RowIdSet`), but the semantics are the same:
/// a set of unique rows.
pub trait RowSetOps: Clone + Sized {
    /// Logical row handle (`u64` for both row ids and row addresses).
    type Row;

    /// Returns true if the set is empty.
    fn is_empty(&self) -> bool;

    /// Returns the number of rows in the set, if it is known.
    ///
    /// Implementations that cannot always compute an exact size (for example
    /// because of "full fragment" markers) should return `None`.
    fn len(&self) -> Option<u64>;

    /// Remove a value from the row set, returning true if it was present.
    fn remove(&mut self, row: Self::Row) -> bool;

    /// Returns whether this set contains the given row.
    fn contains(&self, row: Self::Row) -> bool;

    /// Returns the union of all the given sets.
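    ///
    /// Doctest sketch using `RowAddrTreeMap` as the implementation:
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrTreeMap, RowSetOps};
    ///
    /// let a = RowAddrTreeMap::from_iter([1u64, 2]);
    /// let b = RowAddrTreeMap::from_iter([2u64, 3]);
    /// let union = RowAddrTreeMap::union_all(&[&a, &b]);
    /// assert_eq!(union.len(), Some(3));
    /// ```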
    fn union_all(other: &[&Self]) -> Self;

    /// Builds a row set from an iterator of rows.
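    ///
    /// The input must be sorted; unsorted input yields an error (doctest sketch):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrTreeMap, RowSetOps};
    ///
    /// assert!(RowAddrTreeMap::from_sorted_iter([1u64, 2, 3]).is_ok());
    /// assert!(RowAddrTreeMap::from_sorted_iter([3u64, 1]).is_err());
    /// ```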
    fn from_sorted_iter<I>(iter: I) -> Result<Self>
    where
        I: IntoIterator<Item = Self::Row>;
}

/// A collection of row addresses.
///
/// Note: For stable row id mode, this may be split into a separate structure in the future.
///
/// The stored values may either be stable-style row ids (an incrementing u64
/// sequence) or address-style, where each value is a fragment id and a row offset.
/// When address-style, this supports setting entire fragments as selected,
/// without needing to enumerate all the ids in the fragment.
///
/// This is similar to a [RoaringTreemap] but it is optimized for the case where
/// entire fragments are selected or deselected.
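///
/// A doctest sketch of the full-fragment optimization:
///
/// ```rust
/// use lance_core::utils::mask::{RowAddrTreeMap, RowSetOps};
///
/// let mut map = RowAddrTreeMap::new();
/// map.insert_fragment(2);
/// // Every row address in fragment 2 is now contained...
/// assert!(map.contains((2 << 32) + 5));
/// // ...but the exact size is unknown, so `len` is None.
/// assert_eq!(map.len(), None);
/// ```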
#[derive(Clone, Debug, Default, PartialEq, DeepSizeOf)]
pub struct RowAddrTreeMap {
    /// The contents of the set. If there is a pair (k, Full) then the entire
    /// fragment k is selected. If there is a pair (k, Partial(v)) then the
    /// fragment k has the selected rows in v.
    inner: BTreeMap<u32, RowAddrSelection>,
}

#[derive(Clone, Debug, PartialEq)]
enum RowAddrSelection {
    Full,
    Partial(RoaringBitmap),
}

impl DeepSizeOf for RowAddrSelection {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        match self {
            Self::Full => 0,
            Self::Partial(bitmap) => bitmap.serialized_size(),
        }
    }
}

impl RowAddrSelection {
    fn union_all(selections: &[&Self]) -> Self {
        let mut is_full = false;

        let res = Self::Partial(
            selections
                .iter()
                .filter_map(|selection| match selection {
                    Self::Full => {
                        is_full = true;
                        None
                    }
                    Self::Partial(bitmap) => Some(bitmap),
                })
                .union(),
        );

        if is_full {
            Self::Full
        } else {
            res
        }
    }
}

impl RowSetOps for RowAddrTreeMap {
    type Row = u64;

    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    fn len(&self) -> Option<u64> {
        self.inner
            .values()
            .map(|row_addr_selection| match row_addr_selection {
                RowAddrSelection::Full => None,
                RowAddrSelection::Partial(indices) => Some(indices.len()),
            })
            .try_fold(0_u64, |acc, next| next.map(|next| next + acc))
    }

    fn remove(&mut self, row: Self::Row) -> bool {
        let upper = (row >> 32) as u32;
        let lower = row as u32;
        match self.inner.get_mut(&upper) {
            None => false,
            Some(RowAddrSelection::Full) => {
                let mut set = RoaringBitmap::full();
                set.remove(lower);
                self.inner.insert(upper, RowAddrSelection::Partial(set));
                true
            }
            Some(RowAddrSelection::Partial(lower_set)) => {
                let removed = lower_set.remove(lower);
                if lower_set.is_empty() {
                    self.inner.remove(&upper);
                }
                removed
            }
        }
    }

    fn contains(&self, row: Self::Row) -> bool {
        let upper = (row >> 32) as u32;
        let lower = row as u32;
        match self.inner.get(&upper) {
            None => false,
            Some(RowAddrSelection::Full) => true,
            Some(RowAddrSelection::Partial(fragment_set)) => fragment_set.contains(lower),
        }
    }

    fn union_all(other: &[&Self]) -> Self {
        let mut new_map = BTreeMap::new();

        for map in other {
            for (fragment, selection) in &map.inner {
                new_map
                    .entry(fragment)
                    // I hate this allocation, but I can't think of a better way
                    .or_insert_with(|| Vec::with_capacity(other.len()))
                    .push(selection);
            }
        }

        let new_map = new_map
            .into_iter()
            .map(|(&fragment, selections)| (fragment, RowAddrSelection::union_all(&selections)))
            .collect();

        Self { inner: new_map }
    }

    #[track_caller]
    fn from_sorted_iter<I>(iter: I) -> Result<Self>
    where
        I: IntoIterator<Item = Self::Row>,
    {
        let mut iter = iter.into_iter().peekable();
        let mut inner = BTreeMap::new();

        while let Some(row_id) = iter.peek() {
            let fragment_id = (row_id >> 32) as u32;
            let next_bitmap_iter = iter
                .peeking_take_while(|row_id| (row_id >> 32) as u32 == fragment_id)
                .map(|row_id| row_id as u32);
            let Ok(bitmap) = RoaringBitmap::from_sorted_iter(next_bitmap_iter) else {
                return Err(Error::Internal {
                    message: "RowAddrTreeMap::from_sorted_iter called with non-sorted input"
                        .to_string(),
                    // Use the caller location since we aren't the one that got it out of order
                    location: std::panic::Location::caller().to_snafu_location(),
                });
            };
            inner.insert(fragment_id, RowAddrSelection::Partial(bitmap));
        }

        Ok(Self { inner })
    }
}

impl RowAddrTreeMap {
    /// Create an empty set
    pub fn new() -> Self {
        Self::default()
    }

    /// An iterator of row addrs
    ///
    /// If there are any "full fragment" items then this can't be calculated and None
    /// is returned
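    ///
    /// Doctest sketch:
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowAddrTreeMap;
    ///
    /// let mut map = RowAddrTreeMap::from_iter([1u64, 2]);
    /// assert_eq!(map.row_addrs().unwrap().count(), 2);
    ///
    /// map.insert_fragment(7);
    /// assert!(map.row_addrs().is_none());
    /// ```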
    pub fn row_addrs(&self) -> Option<impl Iterator<Item = RowAddress> + '_> {
        let inner_iters = self
            .inner
            .iter()
            .filter_map(|(frag_id, row_addr_selection)| match row_addr_selection {
                RowAddrSelection::Full => None,
                RowAddrSelection::Partial(bitmap) => Some(
                    bitmap
                        .iter()
                        .map(|row_offset| RowAddress::new_from_parts(*frag_id, row_offset)),
                ),
            })
            .collect::<Vec<_>>();
        if inner_iters.len() != self.inner.len() {
            None
        } else {
            Some(inner_iters.into_iter().flatten())
        }
    }

    /// Insert a single value into the set
    ///
    /// Returns true if the value was not already in the set.
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrTreeMap, RowSetOps};
    ///
    /// let mut set = RowAddrTreeMap::new();
    /// assert_eq!(set.insert(10), true);
    /// assert_eq!(set.insert(10), false);
    /// assert_eq!(set.contains(10), true);
    /// ```
    pub fn insert(&mut self, value: u64) -> bool {
        let fragment = (value >> 32) as u32;
        let row_addr = value as u32;
        match self.inner.get_mut(&fragment) {
            None => {
                let mut set = RoaringBitmap::new();
                set.insert(row_addr);
                self.inner.insert(fragment, RowAddrSelection::Partial(set));
                true
            }
            Some(RowAddrSelection::Full) => false,
            Some(RowAddrSelection::Partial(set)) => set.insert(row_addr),
        }
    }

    /// Insert a range of values into the set
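    ///
    /// Returns the number of newly inserted values (doctest sketch):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrTreeMap, RowSetOps};
    ///
    /// let mut set = RowAddrTreeMap::new();
    /// assert_eq!(set.insert_range(0..5), 5);
    /// // 3 and 4 are already present, so only 5, 6, and 7 are new
    /// assert_eq!(set.insert_range(3..8), 3);
    /// assert!(set.contains(7));
    /// ```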
    pub fn insert_range<R: RangeBounds<u64>>(&mut self, range: R) -> u64 {
        // Separate the start and end into high and low bits.
        let (mut start_high, mut start_low) = match range.start_bound() {
            std::ops::Bound::Included(&start) => ((start >> 32) as u32, start as u32),
            std::ops::Bound::Excluded(&start) => {
                let start = start.saturating_add(1);
                ((start >> 32) as u32, start as u32)
            }
            std::ops::Bound::Unbounded => (0, 0),
        };

        let (end_high, end_low) = match range.end_bound() {
            std::ops::Bound::Included(&end) => ((end >> 32) as u32, end as u32),
            std::ops::Bound::Excluded(&end) => {
                let end = end.saturating_sub(1);
                ((end >> 32) as u32, end as u32)
            }
            std::ops::Bound::Unbounded => (u32::MAX, u32::MAX),
        };

        let mut count = 0;

        while start_high <= end_high {
            let start = start_low;
            let end = if start_high == end_high {
                end_low
            } else {
                u32::MAX
            };
            let fragment = start_high;
            match self.inner.get_mut(&fragment) {
                None => {
                    let mut set = RoaringBitmap::new();
                    count += set.insert_range(start..=end);
                    self.inner.insert(fragment, RowAddrSelection::Partial(set));
                }
                Some(RowAddrSelection::Full) => {}
                Some(RowAddrSelection::Partial(set)) => {
                    count += set.insert_range(start..=end);
                }
            }
            start_high += 1;
            start_low = 0;
        }

        count
    }

    /// Add a bitmap for a single fragment
    pub fn insert_bitmap(&mut self, fragment: u32, bitmap: RoaringBitmap) {
        self.inner
            .insert(fragment, RowAddrSelection::Partial(bitmap));
    }

    /// Add a whole fragment to the set
    pub fn insert_fragment(&mut self, fragment_id: u32) {
        self.inner.insert(fragment_id, RowAddrSelection::Full);
    }

    pub fn get_fragment_bitmap(&self, fragment_id: u32) -> Option<&RoaringBitmap> {
        match self.inner.get(&fragment_id) {
            None => None,
            Some(RowAddrSelection::Full) => None,
            Some(RowAddrSelection::Partial(set)) => Some(set),
        }
    }

    pub fn retain_fragments(&mut self, frag_ids: impl IntoIterator<Item = u32>) {
        let frag_id_set = frag_ids.into_iter().collect::<HashSet<_>>();
        self.inner
            .retain(|frag_id, _| frag_id_set.contains(frag_id));
    }

    /// Compute the serialized size of the set.
    pub fn serialized_size(&self) -> usize {
        // Starts at 4 because of the u32 num_entries
        let mut size = 4;
        for set in self.inner.values() {
            // Each entry is 8 bytes for the fragment id and the bitmap size
            size += 8;
            if let RowAddrSelection::Partial(set) = set {
                size += set.serialized_size();
            }
        }
        size
    }

    /// Serialize the set into the given buffer
    ///
    /// The serialization format is stable and used for index serialization
    ///
    /// The serialization format is:
    /// * u32: num_entries
    ///
    /// for each entry:
    ///   * u32: fragment_id
    ///   * u32: bitmap size
    ///   * \[u8\]: bitmap
    ///
    /// If bitmap size is zero then the entire fragment is selected.
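    ///
    /// Round-trip doctest sketch:
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowAddrTreeMap;
    ///
    /// let mut map = RowAddrTreeMap::from_iter([1u64, 2, 300]);
    /// map.insert_fragment(5);
    ///
    /// let mut buf = Vec::with_capacity(map.serialized_size());
    /// map.serialize_into(&mut buf).unwrap();
    /// assert_eq!(buf.len(), map.serialized_size());
    /// assert_eq!(RowAddrTreeMap::deserialize_from(buf.as_slice()).unwrap(), map);
    /// ```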
    pub fn serialize_into<W: Write>(&self, mut writer: W) -> Result<()> {
        writer.write_u32::<byteorder::LittleEndian>(self.inner.len() as u32)?;
        for (fragment, set) in &self.inner {
            writer.write_u32::<byteorder::LittleEndian>(*fragment)?;
            if let RowAddrSelection::Partial(set) = set {
                writer.write_u32::<byteorder::LittleEndian>(set.serialized_size() as u32)?;
                set.serialize_into(&mut writer)?;
            } else {
                writer.write_u32::<byteorder::LittleEndian>(0)?;
            }
        }
        Ok(())
    }

    /// Deserialize the set from the given buffer
    pub fn deserialize_from<R: Read>(mut reader: R) -> Result<Self> {
        let num_entries = reader.read_u32::<byteorder::LittleEndian>()?;
        let mut inner = BTreeMap::new();
        for _ in 0..num_entries {
            let fragment = reader.read_u32::<byteorder::LittleEndian>()?;
            let bitmap_size = reader.read_u32::<byteorder::LittleEndian>()?;
            if bitmap_size == 0 {
                inner.insert(fragment, RowAddrSelection::Full);
            } else {
                let mut buffer = vec![0; bitmap_size as usize];
                reader.read_exact(&mut buffer)?;
                let set = RoaringBitmap::deserialize_from(&buffer[..])?;
                inner.insert(fragment, RowAddrSelection::Partial(set));
            }
        }
        Ok(Self { inner })
    }

    /// Apply a mask to the row addrs
    ///
    /// For AllowList: only keep rows that are in the allow list
    /// For BlockList: remove rows that are in the block list
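    ///
    /// Doctest sketch (mirrors `test_map_mask` below):
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrMask, RowAddrTreeMap};
    ///
    /// let mut addrs = RowAddrTreeMap::from_iter([0u64, 1, 2]);
    /// let allow = RowAddrMask::from_allowed(RowAddrTreeMap::from_iter([0u64, 2, 3]));
    /// addrs.mask(&allow);
    /// assert_eq!(addrs, RowAddrTreeMap::from_iter([0u64, 2]));
    /// ```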
    pub fn mask(&mut self, mask: &RowAddrMask) {
        match mask {
            RowAddrMask::AllowList(allow_list) => {
                *self &= allow_list;
            }
            RowAddrMask::BlockList(block_list) => {
                *self -= block_list;
            }
        }
    }

    /// Convert the set into an iterator of row addrs
    ///
    /// # Safety
    ///
    /// This is unsafe because if any of the inner RowAddrSelection elements
    /// is not a Partial then the iterator will panic, because the number of
    /// rows in a full fragment is unknown.
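    ///
    /// Doctest sketch (sound here because only `Partial` entries are present):
    ///
    /// ```rust
    /// use lance_core::utils::mask::RowAddrTreeMap;
    ///
    /// let map = RowAddrTreeMap::from_iter([5u64, (1 << 32) | 7]);
    /// let addrs: Vec<u64> = unsafe { map.into_addr_iter().collect() };
    /// assert_eq!(addrs, vec![5, (1 << 32) | 7]);
    /// ```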
    pub unsafe fn into_addr_iter(self) -> impl Iterator<Item = u64> {
        self.inner
            .into_iter()
            .flat_map(|(fragment, selection)| match selection {
                RowAddrSelection::Full => panic!("Size of full fragment is unknown"),
                RowAddrSelection::Partial(bitmap) => bitmap.into_iter().map(move |val| {
                    let fragment = fragment as u64;
                    let row_offset = val as u64;
                    (fragment << 32) | row_offset
                }),
            })
    }
}

impl std::ops::BitOr<Self> for RowAddrTreeMap {
    type Output = Self;

    fn bitor(mut self, rhs: Self) -> Self::Output {
        self |= rhs;
        self
    }
}

impl std::ops::BitOr<&Self> for RowAddrTreeMap {
    type Output = Self;

    fn bitor(mut self, rhs: &Self) -> Self::Output {
        self |= rhs;
        self
    }
}

impl std::ops::BitOrAssign<Self> for RowAddrTreeMap {
    fn bitor_assign(&mut self, rhs: Self) {
        *self |= &rhs;
    }
}

impl std::ops::BitOrAssign<&Self> for RowAddrTreeMap {
    fn bitor_assign(&mut self, rhs: &Self) {
        for (fragment, rhs_set) in &rhs.inner {
            let lhs_set = self.inner.get_mut(fragment);
            if let Some(lhs_set) = lhs_set {
                match lhs_set {
                    RowAddrSelection::Full => {
                        // If the fragment is already selected then there is nothing to do
                    }
                    RowAddrSelection::Partial(lhs_bitmap) => match rhs_set {
                        RowAddrSelection::Full => {
                            *lhs_set = RowAddrSelection::Full;
                        }
                        RowAddrSelection::Partial(rhs_set) => {
                            *lhs_bitmap |= rhs_set;
                        }
                    },
                }
            } else {
                self.inner.insert(*fragment, rhs_set.clone());
            }
        }
    }
}

impl std::ops::BitAnd<Self> for RowAddrTreeMap {
    type Output = Self;

    fn bitand(mut self, rhs: Self) -> Self::Output {
        self &= &rhs;
        self
    }
}

impl std::ops::BitAnd<&Self> for RowAddrTreeMap {
    type Output = Self;

    fn bitand(mut self, rhs: &Self) -> Self::Output {
        self &= rhs;
        self
    }
}

impl std::ops::BitAndAssign<Self> for RowAddrTreeMap {
    fn bitand_assign(&mut self, rhs: Self) {
        *self &= &rhs;
    }
}

impl std::ops::BitAndAssign<&Self> for RowAddrTreeMap {
    fn bitand_assign(&mut self, rhs: &Self) {
        // Remove fragments that aren't on the RHS
        self.inner
            .retain(|fragment, _| rhs.inner.contains_key(fragment));

        // For fragments that are on the RHS, intersect the bitmaps
        for (fragment, mut lhs_set) in &mut self.inner {
            match (&mut lhs_set, rhs.inner.get(fragment)) {
                (_, None) => {} // Already handled by retain
                (_, Some(RowAddrSelection::Full)) => {
                    // Everything selected on RHS, so can leave LHS untouched.
                }
                (RowAddrSelection::Partial(lhs_set), Some(RowAddrSelection::Partial(rhs_set))) => {
                    *lhs_set &= rhs_set;
                }
                (RowAddrSelection::Full, Some(RowAddrSelection::Partial(rhs_set))) => {
                    *lhs_set = RowAddrSelection::Partial(rhs_set.clone());
                }
            }
        }
        // Some bitmaps might now be empty. If they are, we should remove them.
        self.inner.retain(|_, set| match set {
            RowAddrSelection::Partial(set) => !set.is_empty(),
            RowAddrSelection::Full => true,
        });
    }
}

impl std::ops::Sub<Self> for RowAddrTreeMap {
    type Output = Self;

    fn sub(mut self, rhs: Self) -> Self {
        self -= &rhs;
        self
    }
}

impl std::ops::Sub<&Self> for RowAddrTreeMap {
    type Output = Self;

    fn sub(mut self, rhs: &Self) -> Self {
        self -= rhs;
        self
    }
}

impl std::ops::SubAssign<&Self> for RowAddrTreeMap {
    fn sub_assign(&mut self, rhs: &Self) {
        for (fragment, rhs_set) in &rhs.inner {
            match self.inner.get_mut(fragment) {
                None => {}
                Some(RowAddrSelection::Full) => {
                    // The LHS fragment is fully selected; subtract the RHS from it
                    match rhs_set {
                        RowAddrSelection::Full => {
                            self.inner.remove(fragment);
                        }
                        RowAddrSelection::Partial(rhs_set) => {
                            // This generally won't be hit.
                            let mut set = RoaringBitmap::full();
                            set -= rhs_set;
                            self.inner.insert(*fragment, RowAddrSelection::Partial(set));
                        }
                    }
                }
                Some(RowAddrSelection::Partial(lhs_set)) => match rhs_set {
                    RowAddrSelection::Full => {
                        self.inner.remove(fragment);
                    }
                    RowAddrSelection::Partial(rhs_set) => {
                        *lhs_set -= rhs_set;
                        if lhs_set.is_empty() {
                            self.inner.remove(fragment);
                        }
                    }
                },
            }
        }
    }
}

impl FromIterator<u64> for RowAddrTreeMap {
    fn from_iter<T: IntoIterator<Item = u64>>(iter: T) -> Self {
        let mut inner = BTreeMap::new();
        for row_addr in iter {
            let upper = (row_addr >> 32) as u32;
            let lower = row_addr as u32;
            match inner.get_mut(&upper) {
                None => {
                    let mut set = RoaringBitmap::new();
                    set.insert(lower);
                    inner.insert(upper, RowAddrSelection::Partial(set));
                }
                Some(RowAddrSelection::Full) => {
                    // If the fragment is already selected then there is nothing to do
                }
                Some(RowAddrSelection::Partial(set)) => {
                    set.insert(lower);
                }
            }
        }
        Self { inner }
    }
}

impl<'a> FromIterator<&'a u64> for RowAddrTreeMap {
    fn from_iter<T: IntoIterator<Item = &'a u64>>(iter: T) -> Self {
        Self::from_iter(iter.into_iter().copied())
    }
}

impl From<Range<u64>> for RowAddrTreeMap {
    fn from(range: Range<u64>) -> Self {
        let mut map = Self::default();
        map.insert_range(range);
        map
    }
}

impl From<RangeInclusive<u64>> for RowAddrTreeMap {
    fn from(range: RangeInclusive<u64>) -> Self {
        let mut map = Self::default();
        map.insert_range(range);
        map
    }
}

impl From<RoaringTreemap> for RowAddrTreeMap {
    fn from(roaring: RoaringTreemap) -> Self {
        let mut inner = BTreeMap::new();
        for (fragment, set) in roaring.bitmaps() {
            inner.insert(fragment, RowAddrSelection::Partial(set.clone()));
        }
        Self { inner }
    }
}

impl Extend<u64> for RowAddrTreeMap {
    fn extend<T: IntoIterator<Item = u64>>(&mut self, iter: T) {
        for row_addr in iter {
            let upper = (row_addr >> 32) as u32;
            let lower = row_addr as u32;
            match self.inner.get_mut(&upper) {
                None => {
                    let mut set = RoaringBitmap::new();
                    set.insert(lower);
                    self.inner.insert(upper, RowAddrSelection::Partial(set));
                }
                Some(RowAddrSelection::Full) => {
                    // If the fragment is already selected then there is nothing to do
                }
                Some(RowAddrSelection::Partial(set)) => {
                    set.insert(lower);
                }
            }
        }
    }
}

impl<'a> Extend<&'a u64> for RowAddrTreeMap {
    fn extend<T: IntoIterator<Item = &'a u64>>(&mut self, iter: T) {
        self.extend(iter.into_iter().copied())
    }
}

// Extending with RowAddrTreeMap is basically a cumulative set union
impl Extend<Self> for RowAddrTreeMap {
    fn extend<T: IntoIterator<Item = Self>>(&mut self, iter: T) {
        for other in iter {
            for (fragment, set) in other.inner {
                match self.inner.get_mut(&fragment) {
                    None => {
                        self.inner.insert(fragment, set);
                    }
                    Some(RowAddrSelection::Full) => {
                        // If the fragment is already selected then there is nothing to do
                    }
                    Some(RowAddrSelection::Partial(lhs_set)) => match set {
                        RowAddrSelection::Full => {
                            self.inner.insert(fragment, RowAddrSelection::Full);
                        }
                        RowAddrSelection::Partial(rhs_set) => {
                            *lhs_set |= rhs_set;
                        }
                    },
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prop_assert_eq;

    fn rows(ids: &[u64]) -> RowAddrTreeMap {
        RowAddrTreeMap::from_iter(ids)
    }

    fn assert_mask_selects(mask: &RowAddrMask, selected: &[u64], not_selected: &[u64]) {
        for &id in selected {
            assert!(mask.selected(id), "Expected row {} to be selected", id);
        }
        for &id in not_selected {
            assert!(!mask.selected(id), "Expected row {} to NOT be selected", id);
        }
    }

    fn selected_in_range(mask: &RowAddrMask, range: std::ops::Range<u64>) -> Vec<u64> {
        range.filter(|val| mask.selected(*val)).collect()
    }

    #[test]
    fn test_row_addr_mask_construction() {
        let full_mask = RowAddrMask::all_rows();
        assert_eq!(full_mask.max_len(), None);
        assert_mask_selects(&full_mask, &[0, 1, 4 << 32 | 3], &[]);
        assert_eq!(full_mask.allow_list(), None);
        assert_eq!(full_mask.block_list(), Some(&RowAddrTreeMap::default()));
        assert!(full_mask.iter_addrs().is_none());

        let empty_mask = RowAddrMask::allow_nothing();
        assert_eq!(empty_mask.max_len(), Some(0));
        assert_mask_selects(&empty_mask, &[], &[0, 1, 4 << 32 | 3]);
        assert_eq!(empty_mask.allow_list(), Some(&RowAddrTreeMap::default()));
        assert_eq!(empty_mask.block_list(), None);
        let iter = empty_mask.iter_addrs();
        assert!(iter.is_some());
        assert_eq!(iter.unwrap().count(), 0);

        let allow_list = RowAddrMask::from_allowed(rows(&[10, 20, 30]));
        assert_eq!(allow_list.max_len(), Some(3));
        assert_mask_selects(&allow_list, &[10, 20, 30], &[0, 15, 25, 40]);
        assert_eq!(allow_list.allow_list(), Some(&rows(&[10, 20, 30])));
        assert_eq!(allow_list.block_list(), None);
        let iter = allow_list.iter_addrs();
        assert!(iter.is_some());
        let ids: Vec<u64> = iter.unwrap().map(|addr| addr.into()).collect();
        assert_eq!(ids, vec![10, 20, 30]);

        let mut full_frag = RowAddrTreeMap::default();
        full_frag.insert_fragment(2);
        let allow_list = RowAddrMask::from_allowed(full_frag);
        assert_eq!(allow_list.max_len(), None);
        assert_mask_selects(&allow_list, &[(2 << 32) + 5], &[(3 << 32) + 5]);
        assert!(allow_list.iter_addrs().is_none());
    }

    #[test]
    fn test_selected_indices() {
        // Allow list
        let mask = RowAddrMask::from_allowed(rows(&[10, 20, 40]));
        assert!(mask.selected_indices(std::iter::empty()).is_empty());
        assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[1, 3]);

        // Block list
        let mask = RowAddrMask::from_block(rows(&[10, 20, 40]));
        assert!(mask.selected_indices(std::iter::empty()).is_empty());
        assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[0, 2]);
    }

    #[test]
    fn test_also_allow() {
        // Allow list
        let mask = RowAddrMask::from_allowed(rows(&[10, 20]));
        let new_mask = mask.also_allow(rows(&[20, 30, 40]));
        assert_eq!(new_mask, RowAddrMask::from_allowed(rows(&[10, 20, 30, 40])));

        // Block list
        let mask = RowAddrMask::from_block(rows(&[10, 20, 30]));
        let new_mask = mask.also_allow(rows(&[20, 40]));
        assert_eq!(new_mask, RowAddrMask::from_block(rows(&[10, 30])));
    }

    #[test]
    fn test_also_block() {
        // Allow list
        let mask = RowAddrMask::from_allowed(rows(&[10, 20, 30]));
        let new_mask = mask.also_block(rows(&[20, 40]));
        assert_eq!(new_mask, RowAddrMask::from_allowed(rows(&[10, 30])));

        // Block list
        let mask = RowAddrMask::from_block(rows(&[10, 20]));
        let new_mask = mask.also_block(rows(&[20, 30, 40]));
        assert_eq!(new_mask, RowAddrMask::from_block(rows(&[10, 20, 30, 40])));
    }

    #[test]
    fn test_iter_ids() {
        // Allow list
        let mask = RowAddrMask::from_allowed(rows(&[10, 20, 30]));
        let expected: Vec<_> = [10, 20, 30].into_iter().map(RowAddress::from).collect();
        assert_eq!(mask.iter_addrs().unwrap().collect::<Vec<_>>(), expected);

        // Allow list with full fragment
        let mut inner = RowAddrTreeMap::default();
        inner.insert_fragment(10);
        let mask = RowAddrMask::from_allowed(inner);
        assert!(mask.iter_addrs().is_none());

        // Block list
        let mask = RowAddrMask::from_block(rows(&[10, 20, 30]));
        assert!(mask.iter_addrs().is_none());
    }

    #[test]
    fn test_row_addr_mask_not() {
        let allow_list = RowAddrMask::from_allowed(rows(&[1, 2, 3]));
        let block_list = !allow_list.clone();
        assert_eq!(block_list, RowAddrMask::from_block(rows(&[1, 2, 3])));
        // Can roundtrip by negating again
        assert_eq!(!block_list, allow_list);
    }

    #[test]
    fn test_ops() {
        let mask = RowAddrMask::default();
        assert_mask_selects(&mask, &[1, 5], &[]);

        let block_list = mask.also_block(rows(&[0, 5, 15]));
        assert_mask_selects(&block_list, &[1], &[5]);

        let allow_list = RowAddrMask::from_allowed(rows(&[0, 2, 5]));
        assert_mask_selects(&allow_list, &[5], &[1]);

        let combined = block_list & allow_list;
        assert_mask_selects(&combined, &[2], &[0, 5]);

        let other = RowAddrMask::from_allowed(rows(&[3]));
        let combined = combined | other;
        assert_mask_selects(&combined, &[2, 3], &[0, 5]);

        let block_list = RowAddrMask::from_block(rows(&[0]));
        let allow_list = RowAddrMask::from_allowed(rows(&[3]));

        let combined = block_list | allow_list;
        assert_mask_selects(&combined, &[1], &[]);
    }

    #[test]
    fn test_logical_and() {
        let allow1 = RowAddrMask::from_allowed(rows(&[0, 1]));
        let block1 = RowAddrMask::from_block(rows(&[1, 2]));
        let allow2 = RowAddrMask::from_allowed(rows(&[1, 2, 3, 4]));
        let block2 = RowAddrMask::from_block(rows(&[3, 4]));

        fn check(lhs: &RowAddrMask, rhs: &RowAddrMask, expected: &[u64]) {
            for mask in [lhs.clone() & rhs.clone(), rhs.clone() & lhs.clone()] {
                assert_eq!(selected_in_range(&mask, 0..10), expected);
            }
        }

        // Allow & Allow
        check(&allow1, &allow1, &[0, 1]);
        check(&allow1, &allow2, &[1]);

        // Block & Block
        check(&block1, &block1, &[0, 3, 4, 5, 6, 7, 8, 9]);
        check(&block1, &block2, &[0, 5, 6, 7, 8, 9]);

        // Allow & Block
        check(&allow1, &block1, &[0]);
        check(&allow1, &block2, &[0, 1]);
        check(&allow2, &block1, &[3, 4]);
        check(&allow2, &block2, &[1, 2]);
    }

    #[test]
    fn test_logical_or() {
        let allow1 = RowAddrMask::from_allowed(rows(&[5, 6, 7, 8, 9]));
        let block1 = RowAddrMask::from_block(rows(&[5, 6]));
        let mixed1 = allow1.clone().also_block(rows(&[5, 6]));
        let allow2 = RowAddrMask::from_allowed(rows(&[2, 3, 4, 5, 6, 7, 8]));
        let block2 = RowAddrMask::from_block(rows(&[4, 5]));
        let mixed2 = allow2.clone().also_block(rows(&[4, 5]));

        fn check(lhs: &RowAddrMask, rhs: &RowAddrMask, expected: &[u64]) {
            for mask in [lhs.clone() | rhs.clone(), rhs.clone() | lhs.clone()] {
                assert_eq!(selected_in_range(&mask, 0..10), expected);
            }
        }

        check(&allow1, &allow1, &[5, 6, 7, 8, 9]);
        check(&block1, &block1, &[0, 1, 2, 3, 4, 7, 8, 9]);
        check(&mixed1, &mixed1, &[7, 8, 9]);
        check(&allow2, &allow2, &[2, 3, 4, 5, 6, 7, 8]);
        check(&block2, &block2, &[0, 1, 2, 3, 6, 7, 8, 9]);
        check(&mixed2, &mixed2, &[2, 3, 6, 7, 8]);

        check(&allow1, &block1, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        check(&allow1, &mixed1, &[5, 6, 7, 8, 9]);
        check(&allow1, &allow2, &[2, 3, 4, 5, 6, 7, 8, 9]);
        check(&allow1, &block2, &[0, 1, 2, 3, 5, 6, 7, 8, 9]);
        check(&allow1, &mixed2, &[2, 3, 5, 6, 7, 8, 9]);
        check(&block1, &mixed1, &[0, 1, 2, 3, 4, 7, 8, 9]);
        check(&block1, &allow2, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        check(&block1, &block2, &[0, 1, 2, 3, 4, 6, 7, 8, 9]);
        check(&block1, &mixed2, &[0, 1, 2, 3, 4, 6, 7, 8, 9]);
        check(&mixed1, &allow2, &[2, 3, 4, 5, 6, 7, 8, 9]);
        check(&mixed1, &block2, &[0, 1, 2, 3, 6, 7, 8, 9]);
        check(&mixed1, &mixed2, &[2, 3, 6, 7, 8, 9]);
        check(&allow2, &block2, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        check(&allow2, &mixed2, &[2, 3, 4, 5, 6, 7, 8]);
        check(&block2, &mixed2, &[0, 1, 2, 3, 6, 7, 8, 9]);
    }

    #[test]
    fn test_deserialize_legacy_format() {
        // Test that we can deserialize the old format where both allow_list
        // and block_list could be present in the serialized form.
        //
        // The old format (before this PR) used a struct with both allow_list and block_list
        // fields. The new format uses an enum. The deserialization code should handle
        // the case where both lists are present by converting to AllowList(allow - block).

        // Create the RowAddrTreeMaps and serialize them directly
        let allow = rows(&[1, 2, 3, 4, 5, 10, 15]);
        let block = rows(&[2, 4, 15]);

        // Serialize using the stable RowAddrTreeMap serialization format
1157        let block_bytes = {
1158            let mut buf = Vec::with_capacity(block.serialized_size());
1159            block.serialize_into(&mut buf).unwrap();
1160            buf
1161        };
1162        let allow_bytes = {
1163            let mut buf = Vec::with_capacity(allow.serialized_size());
1164            allow.serialize_into(&mut buf).unwrap();
1165            buf
1166        };
1167
1168        // Construct a binary array with both values present (simulating old format)
1169        let old_format_array =
1170            BinaryArray::from_opt_vec(vec![Some(&block_bytes), Some(&allow_bytes)]);
1171
1172        // Deserialize - should handle this by creating AllowList(allow - block)
1173        let deserialized = RowAddrMask::from_arrow(&old_format_array).unwrap();
1174
1175        // The expected result: AllowList([1, 2, 3, 4, 5, 10, 15] - [2, 4, 15]) = [1, 3, 5, 10]
1176        assert_mask_selects(&deserialized, &[1, 3, 5, 10], &[2, 4, 15]);
1177        assert!(
1178            deserialized.allow_list().is_some(),
1179            "Should deserialize to AllowList variant"
1180        );
1181    }
1182
1183    #[test]
1184    fn test_roundtrip_arrow() {
1185        let row_addrs = rows(&[1, 2, 3, 100, 2000]);
1186
1187        // Allow list
1188        let original = RowAddrMask::from_allowed(row_addrs.clone());
1189        let array = original.into_arrow().unwrap();
1190        assert_eq!(RowAddrMask::from_arrow(&array).unwrap(), original);
1191
1192        // Block list
1193        let original = RowAddrMask::from_block(row_addrs);
1194        let array = original.into_arrow().unwrap();
1195        assert_eq!(RowAddrMask::from_arrow(&array).unwrap(), original);
1196    }
1197
1198    #[test]
1199    fn test_deserialize_legacy_empty_lists() {
1200        // Case 1: Both None (should become all_rows)
1201        let array = BinaryArray::from_opt_vec(vec![None, None]);
1202        let mask = RowAddrMask::from_arrow(&array).unwrap();
1203        assert_mask_selects(&mask, &[0, 100, u64::MAX], &[]);
1204
1205        // Case 2: Only block list (no allow list)
1206        let block = rows(&[5, 10]);
1207        let block_bytes = {
1208            let mut buf = Vec::with_capacity(block.serialized_size());
1209            block.serialize_into(&mut buf).unwrap();
1210            buf
1211        };
1212        let array = BinaryArray::from_opt_vec(vec![Some(&block_bytes[..]), None]);
1213        let mask = RowAddrMask::from_arrow(&array).unwrap();
1214        assert_mask_selects(&mask, &[0, 15], &[5, 10]);
1215
1216        // Case 3: Only allow list (no block list)
1217        let allow = rows(&[5, 10]);
1218        let allow_bytes = {
1219            let mut buf = Vec::with_capacity(allow.serialized_size());
1220            allow.serialize_into(&mut buf).unwrap();
1221            buf
1222        };
1223        let array = BinaryArray::from_opt_vec(vec![None, Some(&allow_bytes[..])]);
1224        let mask = RowAddrMask::from_arrow(&array).unwrap();
1225        assert_mask_selects(&mask, &[5, 10], &[0, 15]);
1226    }
1227
1228    #[test]
1229    fn test_map_insert() {
1230        let mut map = RowAddrTreeMap::default();
1231
1232        assert!(!map.contains(20));
1233        assert!(map.insert(20));
1234        assert!(map.contains(20));
1235        assert!(!map.insert(20)); // Inserting again should be no-op
1236
1237        let bitmap = map.get_fragment_bitmap(0);
1238        assert!(bitmap.is_some());
1239        let bitmap = bitmap.unwrap();
1240        assert_eq!(bitmap.len(), 1);
1241
1242        assert!(map.get_fragment_bitmap(1).is_none());
1243
1244        map.insert_fragment(0);
1245        assert!(map.contains(0));
1246        assert!(!map.insert(0)); // Inserting into full fragment should be no-op
1247        assert!(map.get_fragment_bitmap(0).is_none());
1248    }
1249
1250    #[test]
1251    fn test_map_insert_range() {
1252        let ranges = &[
1253            (0..10),
1254            (40..500),
1255            ((u32::MAX as u64 - 10)..(u32::MAX as u64 + 20)),
1256        ];
1257
1258        for range in ranges {
1259            let mut mask = RowAddrTreeMap::default();
1260
1261            let count = mask.insert_range(range.clone());
1262            let expected = range.end - range.start;
1263            assert_eq!(count, expected);
1264
1265            let count = mask.insert_range(range.clone());
1266            assert_eq!(count, 0);
1267
1268            let new_range = range.start + 5..range.end + 5;
1269            let count = mask.insert_range(new_range.clone());
1270            assert_eq!(count, 5);
1271        }
1272
1273        let mut mask = RowAddrTreeMap::default();
1274        let count = mask.insert_range(..10);
1275        assert_eq!(count, 10);
1276        assert!(mask.contains(0));
1277
1278        let count = mask.insert_range(20..=24);
1279        assert_eq!(count, 5);
1280
1281        mask.insert_fragment(0);
1282        let count = mask.insert_range(100..200);
1283        assert_eq!(count, 0);
1284    }
1285
1286    #[test]
1287    fn test_map_remove() {
1288        let mut mask = RowAddrTreeMap::default();
1289
1290        assert!(!mask.remove(20));
1291
1292        mask.insert(20);
1293        assert!(mask.contains(20));
1294        assert!(mask.remove(20));
1295        assert!(!mask.contains(20));
1296
1297        mask.insert_range(10..=20);
1298        assert!(mask.contains(15));
1299        assert!(mask.remove(15));
1300        assert!(!mask.contains(15));
1301
1302        // We don't test removing from a full fragment, because that would take
1303        // a lot of memory.
1304    }
1305
1306    #[test]
1307    fn test_map_mask() {
1308        let mask = rows(&[0, 1, 2]);
1309        let mask2 = rows(&[0, 2, 3]);
1310
1311        let allow_list = RowAddrMask::AllowList(mask2.clone());
1312        let mut actual = mask.clone();
1313        actual.mask(&allow_list);
1314        assert_eq!(actual, rows(&[0, 2]));
1315
1316        let block_list = RowAddrMask::BlockList(mask2);
1317        let mut actual = mask;
1318        actual.mask(&block_list);
1319        assert_eq!(actual, rows(&[1]));
1320    }
1321
1322    #[test]
1323    #[should_panic(expected = "Size of full fragment is unknown")]
1324    fn test_map_insert_full_fragment_row() {
1325        let mut mask = RowAddrTreeMap::default();
1326        mask.insert_fragment(0);
1327
1328        unsafe {
1329            let _ = mask.into_addr_iter().collect::<Vec<u64>>();
1330        }
1331    }
1332
1333    #[test]
1334    fn test_map_into_addr_iter() {
1335        let mut mask = RowAddrTreeMap::default();
1336        mask.insert(0);
1337        mask.insert(1);
1338        mask.insert(1 << 32 | 5);
1339        mask.insert(2 << 32 | 10);
1340
1341        let expected = vec![0u64, 1, 1 << 32 | 5, 2 << 32 | 10];
1342        let actual: Vec<u64> = unsafe { mask.into_addr_iter().collect() };
1343        assert_eq!(actual, expected);
1344    }
1345
1346    #[test]
1347    fn test_map_from() {
1348        let map = RowAddrTreeMap::from(10..12);
1349        assert!(map.contains(10));
1350        assert!(map.contains(11));
1351        assert!(!map.contains(12));
1352        assert!(!map.contains(3));
1353
1354        let map = RowAddrTreeMap::from(10..=12);
1355        assert!(map.contains(10));
1356        assert!(map.contains(11));
1357        assert!(map.contains(12));
1358        assert!(!map.contains(3));
1359    }
1360
1361    #[test]
1362    fn test_map_from_roaring() {
1363        let bitmap = RoaringTreemap::from_iter(&[0, 1, 1 << 32]);
1364        let map = RowAddrTreeMap::from(bitmap);
1365        assert!(map.contains(0) && map.contains(1) && map.contains(1 << 32));
1366        assert!(!map.contains(2));
1367    }
1368
1369    #[test]
1370    fn test_map_extend() {
1371        let mut map = RowAddrTreeMap::default();
1372        map.insert(0);
1373        map.insert_fragment(1);
1374
1375        let other_rows = [0, 2, 1 << 32 | 10, 3 << 32 | 5];
1376        map.extend(other_rows.iter().copied());
1377
1378        assert!(map.contains(0));
1379        assert!(map.contains(2));
1380        assert!(map.contains(1 << 32 | 5));
1381        assert!(map.contains(1 << 32 | 10));
1382        assert!(map.contains(3 << 32 | 5));
1383        assert!(!map.contains(3));
1384    }
1385
1386    #[test]
1387    fn test_map_extend_other_maps() {
1388        let mut map = RowAddrTreeMap::default();
1389        map.insert(0);
1390        map.insert_fragment(1);
1391        map.insert(4 << 32);
1392
1393        let mut other_map = rows(&[0, 2, 1 << 32 | 10, 3 << 32 | 5]);
1394        other_map.insert_fragment(4);
1395        map.extend(std::iter::once(other_map));
1396
1397        for id in [
1398            0,
1399            2,
1400            1 << 32 | 5,
1401            1 << 32 | 10,
1402            3 << 32 | 5,
1403            4 << 32,
1404            4 << 32 | 7,
1405        ] {
1406            assert!(map.contains(id), "Expected {} to be contained", id);
1407        }
1408        assert!(!map.contains(3));
1409    }
1410
1411    proptest::proptest! {
1412        #[test]
1413        fn test_map_serialization_roundtrip(
1414            values in proptest::collection::vec(
1415                (0..u32::MAX, proptest::option::of(proptest::collection::vec(0..u32::MAX, 0..1000))),
1416                0..10
1417            )
1418        ) {
1419            let mut mask = RowAddrTreeMap::default();
1420            for (fragment, rows) in values {
1421                if let Some(rows) = rows {
1422                    let bitmap = RoaringBitmap::from_iter(rows);
1423                    mask.insert_bitmap(fragment, bitmap);
1424                } else {
1425                    mask.insert_fragment(fragment);
1426                }
1427            }
1428
1429            let mut data = Vec::new();
1430            mask.serialize_into(&mut data).unwrap();
1431            let deserialized = RowAddrTreeMap::deserialize_from(data.as_slice()).unwrap();
1432            prop_assert_eq!(mask, deserialized);
1433        }
1434
1435        #[test]
1436        fn test_map_intersect(
1437            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1438            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1439            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
1440            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
1441        ) {
1442            let mut left = RowAddrTreeMap::default();
1443            for fragment in left_full_fragments.clone() {
1444                left.insert_fragment(fragment);
1445            }
1446            left.extend(left_rows.iter().copied());
1447
1448            let mut right = RowAddrTreeMap::default();
1449            for fragment in right_full_fragments.clone() {
1450                right.insert_fragment(fragment);
1451            }
1452            right.extend(right_rows.iter().copied());
1453
            let mut expected = RowAddrTreeMap::default();
            for fragment in &left_full_fragments {
                if right_full_fragments.contains(fragment) {
                    expected.insert_fragment(*fragment);
                }
            }

            let left_in_right = left_rows.iter().filter(|row| {
                right_rows.contains(row)
                    || right_full_fragments.contains(&((*row >> 32) as u32))
            });
            expected.extend(left_in_right);
            let right_in_left = right_rows.iter().filter(|row| {
                left_rows.contains(row)
                    || left_full_fragments.contains(&((*row >> 32) as u32))
            });
            expected.extend(right_in_left);

            let actual = left & right;
            prop_assert_eq!(expected, actual);
        }

        #[test]
        fn test_map_union(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowAddrTreeMap::default();
            for fragment in left_full_fragments.clone() {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowAddrTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }
            right.extend(right_rows.iter().copied());

            let mut expected = RowAddrTreeMap::default();
            for fragment in left_full_fragments {
                expected.insert_fragment(fragment);
            }
            for fragment in right_full_fragments {
                expected.insert_fragment(fragment);
            }

            let combined_rows = left_rows.iter().chain(right_rows.iter());
            expected.extend(combined_rows);

            let actual = left | right;
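            // Compare fragment-by-fragment first so a failure reports the
            // offending key before the whole-map equality check below.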
            for actual_key_val in &actual.inner {
                proptest::prop_assert!(expected.inner.contains_key(actual_key_val.0));
                let expected_val = expected.inner.get(actual_key_val.0).unwrap();
                prop_assert_eq!(
                    actual_key_val.1,
                    expected_val,
                    "error on key {}",
                    actual_key_val.0
                );
            }
            prop_assert_eq!(expected, actual);
        }

        #[test]
        fn test_map_subassign_rows(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowAddrTreeMap::default();
            for fragment in left_full_fragments {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowAddrTreeMap::default();
            right.extend(right_rows.iter().copied());

            let mut expected = left.clone();
            for row in right_rows {
                expected.remove(row);
            }

            left -= &right;
            prop_assert_eq!(expected, left);
        }

        #[test]
        fn test_map_subassign_frags(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            let mut left = RowAddrTreeMap::default();
            for fragment in left_full_fragments {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowAddrTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }

            let mut expected = left.clone();
            for fragment in right_full_fragments {
                expected.inner.remove(&fragment);
            }

            left -= &right;
            prop_assert_eq!(expected, left);
        }

        #[test]
        fn test_from_sorted_iter(
            mut rows in proptest::collection::vec(0..u64::MAX, 0..1000)
        ) {
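            // from_sorted_iter expects its input in ascending order, hence
            // the sort before building the map.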
            rows.sort();
            let num_rows = rows.len();
            let mask = RowAddrTreeMap::from_sorted_iter(rows).unwrap();
            prop_assert_eq!(mask.len(), Some(num_rows as u64));
        }

    }

    #[test]
    fn test_row_addr_selection_deep_size_of() {
        use deepsize::DeepSizeOf;

        // Test Full variant - should have minimal size (just the enum discriminant)
        let full = RowAddrSelection::Full;
        let full_size = full.deep_size_of();
        // Full variant has no heap allocations beyond the enum itself
        assert!(full_size < 100); // Small sanity check

        // Test Partial variant - should include bitmap size
        let mut bitmap = RoaringBitmap::new();
        bitmap.insert_range(0..100);
        let partial = RowAddrSelection::Partial(bitmap.clone());
        let partial_size = partial.deep_size_of();
        // Partial variant should be larger due to bitmap
        assert!(partial_size >= bitmap.serialized_size());
    }

    #[test]
    fn test_row_addr_selection_union_all_with_full() {
        let full = RowAddrSelection::Full;
        let partial = RowAddrSelection::Partial(RoaringBitmap::from_iter(&[1, 2, 3]));

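        // `union_all` yields Full whenever any input is Full; otherwise the
        // partial bitmaps are unioned.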
        assert!(matches!(
            RowAddrSelection::union_all(&[&full, &partial]),
            RowAddrSelection::Full
        ));

        let partial2 = RowAddrSelection::Partial(RoaringBitmap::from_iter(&[4, 5, 6]));
        let RowAddrSelection::Partial(bitmap) = RowAddrSelection::union_all(&[&partial, &partial2])
        else {
            panic!("Expected Partial");
        };
        assert!(bitmap.contains(1) && bitmap.contains(4));
    }

    #[test]
    fn test_insert_range_bounds() {
        let mut map = RowAddrTreeMap::default();

        // Test an exclusive start bound
        let count = map.insert_range((std::ops::Bound::Excluded(5), std::ops::Bound::Included(10)));
        assert_eq!(count, 5); // 6, 7, 8, 9, 10
        assert!(!map.contains(5));
        assert!(map.contains(6));
        assert!(map.contains(10));

        // Test a plain half-open range (included start, excluded end)
        let mut map2 = RowAddrTreeMap::default();
        let count = map2.insert_range(0..5);
        assert_eq!(count, 5);
        assert!(map2.contains(0));
        assert!(map2.contains(4));
        assert!(!map2.contains(5));
    }
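
    // The exclusive-start pair above selects the same rows as the inclusive
    // range `6..=10`; a quick illustrative sketch double-checking that
    // equivalence (not part of the original suite):
    #[test]
    fn test_insert_range_inclusive_equivalence() {
        let mut a = RowAddrTreeMap::default();
        a.insert_range((std::ops::Bound::Excluded(5), std::ops::Bound::Included(10)));
        let mut b = RowAddrTreeMap::default();
        b.insert_range(6u64..=10);
        assert_eq!(a, b);
    }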

    #[test]
    fn test_remove_from_full_fragment() {
        let mut map = RowAddrTreeMap::default();
        map.insert_fragment(0);

        // Verify it's a full fragment - get_fragment_bitmap returns None for Full
        for id in [0, 100, u32::MAX as u64] {
            assert!(map.contains(id));
        }
        assert!(map.get_fragment_bitmap(0).is_none());

        // Remove a value from the full fragment
        assert!(map.remove(50));

        // Now it should be partial (a full RoaringBitmap minus one value)
        assert!(map.contains(0) && !map.contains(50) && map.contains(100));
        assert!(map.get_fragment_bitmap(0).is_some());
    }

    #[test]
    fn test_retain_fragments() {
        let mut map = RowAddrTreeMap::default();
        map.insert(0); // fragment 0
        map.insert(1 << 32 | 5); // fragment 1
        map.insert(2 << 32 | 10); // fragment 2
        map.insert_fragment(3); // fragment 3

        map.retain_fragments([0, 2]);

        assert!(map.contains(0) && map.contains(2 << 32 | 10));
        assert!(!map.contains(1 << 32 | 5) && !map.contains(3 << 32));
    }

    #[test]
    fn test_bitor_assign_full_fragment() {
        // Test BitOrAssign when LHS has Full and RHS has Partial
        let mut map1 = RowAddrTreeMap::default();
        map1.insert_fragment(0);
        let mut map2 = RowAddrTreeMap::default();
        map2.insert(5);

        map1 |= &map2;
        // Full | Partial = Full
        assert!(map1.contains(0) && map1.contains(5) && map1.contains(100));

        // Test BitOrAssign when LHS has Partial and RHS has Full
        let mut map3 = RowAddrTreeMap::default();
        map3.insert(5);
        let mut map4 = RowAddrTreeMap::default();
        map4.insert_fragment(0);

        map3 |= &map4;
        // Partial | Full = Full
        assert!(map3.contains(0) && map3.contains(5) && map3.contains(100));
    }

    #[test]
    fn test_bitand_assign_full_fragments() {
        // Test BitAndAssign when both have Full for same fragment
        let mut map1 = RowAddrTreeMap::default();
        map1.insert_fragment(0);
        let mut map2 = RowAddrTreeMap::default();
        map2.insert_fragment(0);

        map1 &= &map2;
        // Full & Full = Full
        assert!(map1.contains(0) && map1.contains(100));

        // Test BitAndAssign when LHS Full, RHS Partial
        let mut map3 = RowAddrTreeMap::default();
        map3.insert_fragment(0);
        let mut map4 = RowAddrTreeMap::default();
        map4.insert(5);
        map4.insert(10);

        map3 &= &map4;
        // Full & Partial([5,10]) = Partial([5,10])
        assert!(map3.contains(5) && map3.contains(10));
        assert!(!map3.contains(0) && !map3.contains(100));

        // Test that empty intersection results in removal
        let mut map5 = RowAddrTreeMap::default();
        map5.insert(5);
        let mut map6 = RowAddrTreeMap::default();
        map6.insert(10);

        map5 &= &map6;
        assert!(map5.is_empty());
    }

    #[test]
    fn test_sub_assign_with_full_fragments() {
        // Test SubAssign when LHS is Full and RHS is Partial
        let mut map1 = RowAddrTreeMap::default();
        map1.insert_fragment(0);
        let mut map2 = RowAddrTreeMap::default();
        map2.insert(5);
        map2.insert(10);

        map1 -= &map2;
        // Full - Partial([5,10]) = Full minus those values
        assert!(map1.contains(0) && map1.contains(100));
        assert!(!map1.contains(5) && !map1.contains(10));

        // Test SubAssign when both are Full for same fragment
        let mut map3 = RowAddrTreeMap::default();
        map3.insert_fragment(0);
        let mut map4 = RowAddrTreeMap::default();
        map4.insert_fragment(0);

        map3 -= &map4;
        // Full - Full = empty
        assert!(map3.is_empty());

        // Test SubAssign when LHS is Partial and RHS is Full
        let mut map5 = RowAddrTreeMap::default();
        map5.insert(5);
        map5.insert(10);
        let mut map6 = RowAddrTreeMap::default();
        map6.insert_fragment(0);

        map5 -= &map6;
        // Partial - Full = empty
        assert!(map5.is_empty());
    }

    #[test]
    fn test_from_iterator_with_full_fragment() {
        // Test that inserting into a full fragment is a no-op
        let mut map = RowAddrTreeMap::default();
        map.insert_fragment(0);

        // Extend with values that would go into fragment 0
        map.extend([5u64, 10, 100].iter());

        // Should still be a full fragment
        for id in [0, 5, 10, 100, u32::MAX as u64] {
            assert!(map.contains(id));
        }
    }

    #[test]
    fn test_insert_range_excluded_end() {
        // Test excluded end bound (line 391-393)
        let mut map = RowAddrTreeMap::default();
        // A `RangeFrom` such as `5..` would exercise the unbounded-end path;
        // here we target an explicit `Bound::Excluded` end instead
        let count = map.insert_range((std::ops::Bound::Included(5), std::ops::Bound::Excluded(10)));
        assert_eq!(count, 5); // 5, 6, 7, 8, 9
        assert!(map.contains(5));
        assert!(map.contains(9));
        assert!(!map.contains(10));
    }

    #[test]
    fn test_bitand_assign_owned() {
        // Test BitAndAssign<Self> (an owned rhs, not a reference)
        let mut map1 = RowAddrTreeMap::default();
        map1.insert(5);
        map1.insert(10);

        // Use an owned rhs (not a reference)
        map1 &= rows(&[5, 15]);

        assert!(map1.contains(5));
        assert!(!map1.contains(10) && !map1.contains(15));
    }

    #[test]
    fn test_from_iter_with_full_fragment() {
        // Collecting an iterator of row ids builds a map of partial fragments
        let map: RowAddrTreeMap = vec![5u64, 10, 100].into_iter().collect();
        assert!(map.contains(5) && map.contains(10));

        // Inserting into a map whose fragment is already full ignores the new values
        let mut map = RowAddrTreeMap::default();
        map.insert_fragment(0);
        for val in [5, 10, 100] {
            map.insert(val); // No-op: fragment 0 is already full
        }
        // Still a full fragment
        for id in [0, 5, u32::MAX as u64] {
            assert!(map.contains(id));
        }
    }
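
    // The operator tests above suggest the usual set identities hold; here is
    // a small, concrete spot-check of the absorption law (A | B) & B == B,
    // added as an illustrative sketch rather than part of the original suite:
    #[test]
    fn test_union_then_intersect_absorption() {
        let mut a = RowAddrTreeMap::default();
        a.extend([1u64, 2, 3]);
        let mut b = RowAddrTreeMap::default();
        b.extend([3u64, 4]);
        b.insert_fragment(7);

        // Unioning then intersecting with b should give b back exactly:
        // partial rows from a alone are dropped, and the Full fragment survives
        let union = a | b.clone();
        assert_eq!(union & b.clone(), b);
    }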
}