Skip to main content

lance_core/utils/
mask.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright The Lance Authors
3
4use std::collections::HashSet;
5use std::io::Write;
6use std::ops::{Range, RangeBounds, RangeInclusive};
7use std::{collections::BTreeMap, io::Read};
8
9use arrow_array::{Array, BinaryArray, GenericBinaryArray};
10use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer};
11use byteorder::{ReadBytesExt, WriteBytesExt};
12use deepsize::DeepSizeOf;
13use itertools::Itertools;
14use roaring::{MultiOps, RoaringBitmap, RoaringTreemap};
15
16use crate::{Error, Result};
17
18use super::address::RowAddress;
19
20mod nullable;
21
22pub use nullable::{NullableRowAddrMask, NullableRowAddrSet};
23
/// A mask that selects or deselects rows based on an allow-list or block-list.
#[derive(Clone, Debug, DeepSizeOf, PartialEq)]
pub enum RowAddrMask {
    /// Only the rows in the set are selected.
    AllowList(RowAddrTreeMap),
    /// All rows except those in the set are selected.
    BlockList(RowAddrTreeMap),
}
30
31impl Default for RowAddrMask {
32    fn default() -> Self {
33        // Empty block list means all rows are allowed
34        Self::BlockList(RowAddrTreeMap::new())
35    }
36}
37
impl RowAddrMask {
    /// Create a mask allowing all rows, this is an alias for [default]
    pub fn all_rows() -> Self {
        Self::default()
    }

    /// Create a mask that doesn't allow anything (an empty allow list)
    pub fn allow_nothing() -> Self {
        Self::AllowList(RowAddrTreeMap::new())
    }

    /// Create a mask from an allow list
    pub fn from_allowed(allow_list: RowAddrTreeMap) -> Self {
        Self::AllowList(allow_list)
    }

    /// Create a mask from a block list
    pub fn from_block(block_list: RowAddrTreeMap) -> Self {
        Self::BlockList(block_list)
    }

    /// Returns the block list, or `None` if this mask is an allow list
    pub fn block_list(&self) -> Option<&RowAddrTreeMap> {
        match self {
            Self::BlockList(block_list) => Some(block_list),
            _ => None,
        }
    }

    /// Returns the allow list, or `None` if this mask is a block list
    pub fn allow_list(&self) -> Option<&RowAddrTreeMap> {
        match self {
            Self::AllowList(allow_list) => Some(allow_list),
            _ => None,
        }
    }

    /// True if the row_id is selected by the mask, false otherwise
    pub fn selected(&self, row_id: u64) -> bool {
        match self {
            Self::AllowList(allow_list) => allow_list.contains(row_id),
            Self::BlockList(block_list) => !block_list.contains(row_id),
        }
    }

    /// Return the indices of the input row ids that were valid
    ///
    /// The returned values index into the input iterator (0-based), not into
    /// the mask itself.
    pub fn selected_indices<'a>(&self, row_ids: impl Iterator<Item = &'a u64> + 'a) -> Vec<u64> {
        row_ids
            .enumerate()
            .filter_map(|(idx, row_id)| {
                if self.selected(*row_id) {
                    Some(idx as u64)
                } else {
                    None
                }
            })
            .collect()
    }

    /// Also block the given addrs
    ///
    /// For an allow list this shrinks the allowed set (set difference); for a
    /// block list it grows the blocked set (set union).
    pub fn also_block(self, block_list: RowAddrTreeMap) -> Self {
        match self {
            Self::AllowList(allow_list) => Self::AllowList(allow_list - block_list),
            Self::BlockList(existing) => Self::BlockList(existing | block_list),
        }
    }

    /// Also allow the given addrs
    ///
    /// For an allow list this grows the allowed set (set union); for a block
    /// list it shrinks the blocked set (set difference).
    pub fn also_allow(self, allow_list: RowAddrTreeMap) -> Self {
        match self {
            Self::AllowList(existing) => Self::AllowList(existing | allow_list),
            Self::BlockList(block_list) => Self::BlockList(block_list - allow_list),
        }
    }

    /// Convert a mask into an arrow array
    ///
    /// A row addr mask is not very arrow-compatible.  We can't make it a batch with
    /// two columns because the block list and allow list will have different lengths.  Also,
    /// there is no Arrow type for compressed bitmaps.
    ///
    /// However, we need to shove it into some kind of Arrow container to pass it along the
    /// datafusion stream.  Perhaps, in the future, we can add row addr masks as first class
    /// types in datafusion, and this can be passed along as a mask / selection vector.
    ///
    /// We serialize this as a variable length binary array with two items.  The first item
    /// is the block list and the second item is the allow list.
    pub fn into_arrow(&self) -> Result<BinaryArray> {
        // NOTE: This serialization format must be stable as it is used in IPC.
        // Exactly one of the two entries is populated; the other is null.
        let (block_list, allow_list) = match self {
            Self::AllowList(allow_list) => (None, Some(allow_list)),
            Self::BlockList(block_list) => (Some(block_list), None),
        };

        let block_list_length = block_list
            .as_ref()
            .map(|bl| bl.serialized_size())
            .unwrap_or(0);
        let allow_list_length = allow_list
            .as_ref()
            .map(|al| al.serialized_size())
            .unwrap_or(0);
        let lengths = vec![block_list_length, allow_list_length];
        let offsets = OffsetBuffer::from_lengths(lengths);
        // Values buffer: block list bytes first, then allow list bytes.
        let mut value_bytes = vec![0; block_list_length + allow_list_length];
        let mut validity = vec![false, false];
        if let Some(block_list) = &block_list {
            validity[0] = true;
            block_list.serialize_into(&mut value_bytes[0..])?;
        }
        if let Some(allow_list) = &allow_list {
            validity[1] = true;
            allow_list.serialize_into(&mut value_bytes[block_list_length..])?;
        }
        let values = Buffer::from(value_bytes);
        let nulls = NullBuffer::from(validity);
        Ok(BinaryArray::try_new(offsets, values, Some(nulls))?)
    }

    /// Deserialize a row address mask from Arrow
    ///
    /// Expects the two-entry layout produced by [`Self::into_arrow`]: index 0
    /// is the block list, index 1 is the allow list, with null meaning absent.
    /// NOTE(review): assumes the array has at least two entries — confirm no
    /// caller passes a shorter array.
    pub fn from_arrow(array: &GenericBinaryArray<i32>) -> Result<Self> {
        let block_list = if array.is_null(0) {
            None
        } else {
            Some(RowAddrTreeMap::deserialize_from(array.value(0)))
        }
        .transpose()?;

        let allow_list = if array.is_null(1) {
            None
        } else {
            Some(RowAddrTreeMap::deserialize_from(array.value(1)))
        }
        .transpose()?;

        let res = match (block_list, allow_list) {
            (Some(bl), None) => Self::BlockList(bl),
            (None, Some(al)) => Self::AllowList(al),
            // Both present (not produced by into_arrow, but tolerated):
            // allow first, then subtract the block list.
            (Some(block), Some(allow)) => Self::AllowList(allow).also_block(block),
            (None, None) => Self::all_rows(),
        };
        Ok(res)
    }

    /// Return the maximum number of row addresses that could be selected by this mask
    ///
    /// Will be None if this is a BlockList (unbounded)
    pub fn max_len(&self) -> Option<u64> {
        match self {
            // Also None if the allow list contains any "full fragment" entries.
            Self::AllowList(selection) => selection.len(),
            Self::BlockList(_) => None,
        }
    }

    /// Iterate over the row addresses that are selected by the mask
    ///
    /// This is only possible if this is an AllowList and the maps don't contain
    /// any "full fragment" blocks.
    pub fn iter_addrs(&self) -> Option<Box<dyn Iterator<Item = RowAddress> + '_>> {
        match self {
            Self::AllowList(allow_list) => {
                if let Some(allow_iter) = allow_list.row_addrs() {
                    Some(Box::new(allow_iter))
                } else {
                    None
                }
            }
            Self::BlockList(_) => None, // Can't iterate over block list
        }
    }
}
207
208impl std::ops::Not for RowAddrMask {
209    type Output = Self;
210
211    fn not(self) -> Self::Output {
212        match self {
213            Self::AllowList(allow_list) => Self::BlockList(allow_list),
214            Self::BlockList(block_list) => Self::AllowList(block_list),
215        }
216    }
217}
218
219impl std::ops::BitAnd for RowAddrMask {
220    type Output = Self;
221
222    fn bitand(self, rhs: Self) -> Self::Output {
223        match (self, rhs) {
224            (Self::AllowList(a), Self::AllowList(b)) => Self::AllowList(a & b),
225            (Self::AllowList(allow), Self::BlockList(block))
226            | (Self::BlockList(block), Self::AllowList(allow)) => Self::AllowList(allow - block),
227            (Self::BlockList(a), Self::BlockList(b)) => Self::BlockList(a | b),
228        }
229    }
230}
231
232impl std::ops::BitOr for RowAddrMask {
233    type Output = Self;
234
235    fn bitor(self, rhs: Self) -> Self::Output {
236        match (self, rhs) {
237            (Self::AllowList(a), Self::AllowList(b)) => Self::AllowList(a | b),
238            (Self::AllowList(allow), Self::BlockList(block))
239            | (Self::BlockList(block), Self::AllowList(allow)) => Self::BlockList(block - allow),
240            (Self::BlockList(a), Self::BlockList(b)) => Self::BlockList(a & b),
241        }
242    }
243}
244
/// Common operations over a set of rows (either row ids or row addresses).
///
/// The concrete representation can be address-based (`RowAddrTreeMap`) or
/// id-based (for example a future `RowIdSet`), but the semantics are the same:
/// a set of unique rows.
pub trait RowSetOps: Clone + Sized {
    /// Logical row handle (`u64` for both row ids and row addresses).
    type Row;

    /// Returns true if the set is empty.
    fn is_empty(&self) -> bool;

    /// Returns the number of rows in the set, if it is known.
    ///
    /// Implementations that cannot always compute an exact size (for example
    /// because of "full fragment" markers) should return `None`.
    fn len(&self) -> Option<u64>;

    /// Remove a value from the row set.
    ///
    /// Returns true if the row was present and removed.
    fn remove(&mut self, row: Self::Row) -> bool;

    /// Returns whether this set contains the given row.
    fn contains(&self, row: Self::Row) -> bool;

    /// Returns the union of all the given sets.
    fn union_all(other: &[&Self]) -> Self;

    /// Builds a row set from an iterator of rows.
    ///
    /// The input must be sorted in ascending order; implementations may
    /// return an error if it is not.
    fn from_sorted_iter<I>(iter: I) -> Result<Self>
    where
        I: IntoIterator<Item = Self::Row>;
}
277
/// A collection of row addresses.
///
/// Note: For stable row id mode, this may be split into a separate structure in the future.
///
/// These row ids may either be stable-style (where they can be an incrementing
/// u64 sequence) or address style, where they are a fragment id and a row offset.
/// When address style, this supports setting entire fragments as selected,
/// without needing to enumerate all the ids in the fragment.
///
/// This is similar to a [RoaringTreemap] but it is optimized for the case where
/// entire fragments are selected or deselected.
#[derive(Clone, Debug, Default, PartialEq, DeepSizeOf)]
pub struct RowAddrTreeMap {
    /// The contents of the set. If there is a pair (k, Full) then the entire
    /// fragment k is selected. If there is a pair (k, Partial(v)) then the
    /// fragment k has the selected rows in v.
    ///
    /// The key is the upper 32 bits of the row address (the fragment id).
    inner: BTreeMap<u32, RowAddrSelection>,
}
296
/// The selection state of a single fragment within a [RowAddrTreeMap].
#[derive(Clone, Debug, PartialEq)]
pub enum RowAddrSelection {
    /// Every row offset in the fragment is selected.
    Full,
    /// Only the row offsets present in the bitmap are selected.
    Partial(RoaringBitmap),
}
302
impl DeepSizeOf for RowAddrSelection {
    fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize {
        match self {
            // `Full` carries no heap data beyond the enum itself.
            Self::Full => 0,
            // NOTE(review): reports the bitmap's *serialized* size as a proxy
            // for its heap footprint; the in-memory size may differ.
            Self::Partial(bitmap) => bitmap.serialized_size(),
        }
    }
}
311
312impl RowAddrSelection {
313    fn union_all(selections: &[&Self]) -> Self {
314        let mut is_full = false;
315
316        let res = Self::Partial(
317            selections
318                .iter()
319                .filter_map(|selection| match selection {
320                    Self::Full => {
321                        is_full = true;
322                        None
323                    }
324                    Self::Partial(bitmap) => Some(bitmap),
325                })
326                .union(),
327        );
328
329        if is_full { Self::Full } else { res }
330    }
331}
332
impl RowSetOps for RowAddrTreeMap {
    type Row = u64;

    /// True if no fragments are present.
    ///
    /// NOTE(review): assumes empty partial bitmaps are always pruned from
    /// `inner` (as `remove` and `bitand_assign` do) — confirm every mutation
    /// path upholds that invariant.
    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Exact number of selected rows, or `None` if any fragment is marked
    /// `Full` (the row count of a full fragment is not known here).
    fn len(&self) -> Option<u64> {
        self.inner
            .values()
            .map(|row_addr_selection| match row_addr_selection {
                RowAddrSelection::Full => None,
                RowAddrSelection::Partial(indices) => Some(indices.len()),
            })
            // try_fold short-circuits to None on the first Full fragment.
            .try_fold(0_u64, |acc, next| next.map(|next| next + acc))
    }

    /// Remove a single row address; returns true if it was present.
    fn remove(&mut self, row: Self::Row) -> bool {
        // Row address = (fragment id << 32) | row offset.
        let upper = (row >> 32) as u32;
        let lower = row as u32;
        match self.inner.get_mut(&upper) {
            None => false,
            Some(RowAddrSelection::Full) => {
                // Materialize the full fragment as a bitmap of every u32
                // offset, then drop just this row.
                let mut set = RoaringBitmap::full();
                set.remove(lower);
                self.inner.insert(upper, RowAddrSelection::Partial(set));
                true
            }
            Some(RowAddrSelection::Partial(lower_set)) => {
                let removed = lower_set.remove(lower);
                // Prune empty bitmaps so `is_empty` stays accurate.
                if lower_set.is_empty() {
                    self.inner.remove(&upper);
                }
                removed
            }
        }
    }

    fn contains(&self, row: Self::Row) -> bool {
        // Row address = (fragment id << 32) | row offset.
        let upper = (row >> 32) as u32;
        let lower = row as u32;
        match self.inner.get(&upper) {
            None => false,
            Some(RowAddrSelection::Full) => true,
            Some(RowAddrSelection::Partial(fragment_set)) => fragment_set.contains(lower),
        }
    }

    /// Union of all the given maps.
    fn union_all(other: &[&Self]) -> Self {
        // Group every per-fragment selection by fragment id, then union each group.
        let mut new_map = BTreeMap::new();

        for map in other {
            for (fragment, selection) in &map.inner {
                new_map
                    .entry(fragment)
                    // I hate this allocation, but I can't think of a better way
                    .or_insert_with(|| Vec::with_capacity(other.len()))
                    .push(selection);
            }
        }

        let new_map = new_map
            .into_iter()
            .map(|(&fragment, selections)| (fragment, RowAddrSelection::union_all(&selections)))
            .collect();

        Self { inner: new_map }
    }

    /// Build a map from an ascending iterator of row addresses.
    ///
    /// Returns an error if the input is not sorted (detected per fragment by
    /// `RoaringBitmap::from_sorted_iter`).
    #[track_caller]
    fn from_sorted_iter<I>(iter: I) -> Result<Self>
    where
        I: IntoIterator<Item = Self::Row>,
    {
        let mut iter = iter.into_iter().peekable();
        let mut inner = BTreeMap::new();

        // Consume one fragment's worth of consecutive addresses at a time.
        while let Some(row_id) = iter.peek() {
            let fragment_id = (row_id >> 32) as u32;
            let next_bitmap_iter = iter
                .peeking_take_while(|row_id| (row_id >> 32) as u32 == fragment_id)
                .map(|row_id| row_id as u32);
            let Ok(bitmap) = RoaringBitmap::from_sorted_iter(next_bitmap_iter) else {
                return Err(Error::internal(
                    "RowAddrTreeMap::from_sorted_iter called with non-sorted input",
                ));
            };
            inner.insert(fragment_id, RowAddrSelection::Partial(bitmap));
        }

        Ok(Self { inner })
    }
}
426
impl RowAddrTreeMap {
    /// Create an empty set
    pub fn new() -> Self {
        Self::default()
    }

    /// An iterator of row addrs
    ///
    /// If there are any "full fragment" items then this can't be calculated and None
    /// is returned
    pub fn row_addrs(&self) -> Option<impl Iterator<Item = RowAddress> + '_> {
        // One sub-iterator per `Partial` fragment; `Full` fragments are dropped
        // here and detected by the length check below.
        let inner_iters = self
            .inner
            .iter()
            .filter_map(|(frag_id, row_addr_selection)| match row_addr_selection {
                RowAddrSelection::Full => None,
                RowAddrSelection::Partial(bitmap) => Some(
                    bitmap
                        .iter()
                        .map(|row_offset| RowAddress::new_from_parts(*frag_id, row_offset)),
                ),
            })
            .collect::<Vec<_>>();
        if inner_iters.len() != self.inner.len() {
            // At least one fragment was `Full`, so the addresses can't be enumerated.
            None
        } else {
            Some(inner_iters.into_iter().flatten())
        }
    }

    /// Insert a single value into the set
    ///
    /// Returns true if the value was not already in the set.
    ///
    /// ```rust
    /// use lance_core::utils::mask::{RowAddrTreeMap, RowSetOps};
    ///
    /// let mut set = RowAddrTreeMap::new();
    /// assert_eq!(set.insert(10), true);
    /// assert_eq!(set.insert(10), false);
    /// assert_eq!(set.contains(10), true);
    /// ```
    pub fn insert(&mut self, value: u64) -> bool {
        // Row address = (fragment id << 32) | row offset.
        let fragment = (value >> 32) as u32;
        let row_addr = value as u32;
        match self.inner.get_mut(&fragment) {
            None => {
                let mut set = RoaringBitmap::new();
                set.insert(row_addr);
                self.inner.insert(fragment, RowAddrSelection::Partial(set));
                true
            }
            // The whole fragment is already selected; nothing to add.
            Some(RowAddrSelection::Full) => false,
            Some(RowAddrSelection::Partial(set)) => set.insert(row_addr),
        }
    }

    /// Insert a range of values into the set
    ///
    /// Returns the count of values inserted, as reported by the underlying
    /// `RoaringBitmap::insert_range` calls; rows in fragments already marked
    /// `Full` are not counted.
    ///
    /// NOTE(review): `Excluded` bounds are normalized with saturating
    /// arithmetic, so `Excluded(u64::MAX)` as a start bound (or `Excluded(0)`
    /// as an end bound) is treated as inclusive rather than producing an empty
    /// range — confirm callers never pass these degenerate bounds.
    pub fn insert_range<R: RangeBounds<u64>>(&mut self, range: R) -> u64 {
        // Separate the start and end into high and low bits.
        let (mut start_high, mut start_low) = match range.start_bound() {
            std::ops::Bound::Included(&start) => ((start >> 32) as u32, start as u32),
            std::ops::Bound::Excluded(&start) => {
                let start = start.saturating_add(1);
                ((start >> 32) as u32, start as u32)
            }
            std::ops::Bound::Unbounded => (0, 0),
        };

        let (end_high, end_low) = match range.end_bound() {
            std::ops::Bound::Included(&end) => ((end >> 32) as u32, end as u32),
            std::ops::Bound::Excluded(&end) => {
                let end = end.saturating_sub(1);
                ((end >> 32) as u32, end as u32)
            }
            std::ops::Bound::Unbounded => (u32::MAX, u32::MAX),
        };

        let mut count = 0;

        // Walk one fragment at a time, filling [start, end] within each.
        while start_high <= end_high {
            let start = start_low;
            let end = if start_high == end_high {
                end_low
            } else {
                // Not the last fragment: fill it to the end.
                u32::MAX
            };
            let fragment = start_high;
            match self.inner.get_mut(&fragment) {
                None => {
                    let mut set = RoaringBitmap::new();
                    count += set.insert_range(start..=end);
                    self.inner.insert(fragment, RowAddrSelection::Partial(set));
                }
                // Already fully selected; nothing to insert (and nothing counted).
                Some(RowAddrSelection::Full) => {}
                Some(RowAddrSelection::Partial(set)) => {
                    count += set.insert_range(start..=end);
                }
            }
            start_high += 1;
            // Subsequent fragments start at offset 0.
            start_low = 0;
        }

        count
    }

    /// Add a bitmap for a single fragment
    ///
    /// NOTE(review): this *replaces* any existing selection for the fragment
    /// rather than merging with it — confirm callers expect overwrite semantics.
    pub fn insert_bitmap(&mut self, fragment: u32, bitmap: RoaringBitmap) {
        self.inner
            .insert(fragment, RowAddrSelection::Partial(bitmap));
    }

    /// Add a whole fragment to the set
    pub fn insert_fragment(&mut self, fragment_id: u32) {
        self.inner.insert(fragment_id, RowAddrSelection::Full);
    }

    /// Get the partial bitmap for a fragment
    ///
    /// Returns `None` both when the fragment is absent and when it is marked
    /// `Full` (a full fragment has no explicit bitmap).
    pub fn get_fragment_bitmap(&self, fragment_id: u32) -> Option<&RoaringBitmap> {
        match self.inner.get(&fragment_id) {
            None => None,
            Some(RowAddrSelection::Full) => None,
            Some(RowAddrSelection::Partial(set)) => Some(set),
        }
    }

    /// Get the selection for a fragment
    pub fn get(&self, fragment_id: &u32) -> Option<&RowAddrSelection> {
        self.inner.get(fragment_id)
    }

    /// Iterate over (fragment_id, selection) pairs
    pub fn iter(&self) -> impl Iterator<Item = (&u32, &RowAddrSelection)> {
        self.inner.iter()
    }

    /// Keep only the fragments whose id appears in `frag_ids`.
    pub fn retain_fragments(&mut self, frag_ids: impl IntoIterator<Item = u32>) {
        let frag_id_set = frag_ids.into_iter().collect::<HashSet<_>>();
        self.inner
            .retain(|frag_id, _| frag_id_set.contains(frag_id));
    }

    /// Compute the serialized size of the set.
    ///
    /// Must agree byte-for-byte with the format written by [`Self::serialize_into`].
    pub fn serialized_size(&self) -> usize {
        // Starts at 4 because of the u32 num_entries
        let mut size = 4;
        for set in self.inner.values() {
            // Each entry is 8 bytes for the fragment id and the bitmap size
            size += 8;
            if let RowAddrSelection::Partial(set) = set {
                size += set.serialized_size();
            }
        }
        size
    }

    /// Serialize the set into the given buffer
    ///
    /// The serialization format is stable and used for index serialization
    ///
    /// The serialization format is:
    /// * u32: num_entries
    ///
    /// for each entry:
    ///   * u32: fragment_id
    ///   * u32: bitmap size
    ///   * \[u8\]: bitmap
    ///
    /// If bitmap size is zero then the entire fragment is selected.
    pub fn serialize_into<W: Write>(&self, mut writer: W) -> Result<()> {
        writer.write_u32::<byteorder::LittleEndian>(self.inner.len() as u32)?;
        for (fragment, set) in &self.inner {
            writer.write_u32::<byteorder::LittleEndian>(*fragment)?;
            if let RowAddrSelection::Partial(set) = set {
                writer.write_u32::<byteorder::LittleEndian>(set.serialized_size() as u32)?;
                set.serialize_into(&mut writer)?;
            } else {
                // A size of zero marks a full fragment (no bitmap bytes follow).
                writer.write_u32::<byteorder::LittleEndian>(0)?;
            }
        }
        Ok(())
    }

    /// Deserialize the set from the given buffer
    ///
    /// Inverse of [`Self::serialize_into`]; a bitmap size of zero denotes a
    /// full fragment.
    pub fn deserialize_from<R: Read>(mut reader: R) -> Result<Self> {
        let num_entries = reader.read_u32::<byteorder::LittleEndian>()?;
        let mut inner = BTreeMap::new();
        for _ in 0..num_entries {
            let fragment = reader.read_u32::<byteorder::LittleEndian>()?;
            let bitmap_size = reader.read_u32::<byteorder::LittleEndian>()?;
            if bitmap_size == 0 {
                inner.insert(fragment, RowAddrSelection::Full);
            } else {
                let mut buffer = vec![0; bitmap_size as usize];
                reader.read_exact(&mut buffer)?;
                let set = RoaringBitmap::deserialize_from(&buffer[..])?;
                inner.insert(fragment, RowAddrSelection::Partial(set));
            }
        }
        Ok(Self { inner })
    }

    /// Apply a mask to the row addrs
    ///
    /// For AllowList: keep only the rows also present in the allow list
    /// (set intersection).
    /// For BlockList: remove the rows present in the block list (set difference).
    pub fn mask(&mut self, mask: &RowAddrMask) {
        match mask {
            RowAddrMask::AllowList(allow_list) => {
                *self &= allow_list;
            }
            RowAddrMask::BlockList(block_list) => {
                *self -= block_list;
            }
        }
    }

    /// Convert the set into an iterator of row addrs
    ///
    /// # Safety
    ///
    /// This is unsafe because if any of the inner RowAddrSelection elements
    /// is not a Partial then the iterator will panic because we don't know
    /// the size of the bitmap.  Note that the panic happens lazily, when the
    /// offending fragment is reached during iteration.
    pub unsafe fn into_addr_iter(self) -> impl Iterator<Item = u64> {
        self.inner
            .into_iter()
            .flat_map(|(fragment, selection)| match selection {
                RowAddrSelection::Full => panic!("Size of full fragment is unknown"),
                RowAddrSelection::Partial(bitmap) => bitmap.into_iter().map(move |val| {
                    // Reassemble the full address: (fragment << 32) | offset.
                    let fragment = fragment as u64;
                    let row_offset = val as u64;
                    (fragment << 32) | row_offset
                }),
            })
    }
}
663
664impl std::ops::BitOr<Self> for RowAddrTreeMap {
665    type Output = Self;
666
667    fn bitor(mut self, rhs: Self) -> Self::Output {
668        self |= rhs;
669        self
670    }
671}
672
673impl std::ops::BitOr<&Self> for RowAddrTreeMap {
674    type Output = Self;
675
676    fn bitor(mut self, rhs: &Self) -> Self::Output {
677        self |= rhs;
678        self
679    }
680}
681
682impl std::ops::BitOrAssign<Self> for RowAddrTreeMap {
683    fn bitor_assign(&mut self, rhs: Self) {
684        *self |= &rhs;
685    }
686}
687
688impl std::ops::BitOrAssign<&Self> for RowAddrTreeMap {
689    fn bitor_assign(&mut self, rhs: &Self) {
690        for (fragment, rhs_set) in &rhs.inner {
691            let lhs_set = self.inner.get_mut(fragment);
692            if let Some(lhs_set) = lhs_set {
693                match lhs_set {
694                    RowAddrSelection::Full => {
695                        // If the fragment is already selected then there is nothing to do
696                    }
697                    RowAddrSelection::Partial(lhs_bitmap) => match rhs_set {
698                        RowAddrSelection::Full => {
699                            *lhs_set = RowAddrSelection::Full;
700                        }
701                        RowAddrSelection::Partial(rhs_set) => {
702                            *lhs_bitmap |= rhs_set;
703                        }
704                    },
705                }
706            } else {
707                self.inner.insert(*fragment, rhs_set.clone());
708            }
709        }
710    }
711}
712
713impl std::ops::BitAnd<Self> for RowAddrTreeMap {
714    type Output = Self;
715
716    fn bitand(mut self, rhs: Self) -> Self::Output {
717        self &= &rhs;
718        self
719    }
720}
721
722impl std::ops::BitAnd<&Self> for RowAddrTreeMap {
723    type Output = Self;
724
725    fn bitand(mut self, rhs: &Self) -> Self::Output {
726        self &= rhs;
727        self
728    }
729}
730
731impl std::ops::BitAndAssign<Self> for RowAddrTreeMap {
732    fn bitand_assign(&mut self, rhs: Self) {
733        *self &= &rhs;
734    }
735}
736
737impl std::ops::BitAndAssign<&Self> for RowAddrTreeMap {
738    fn bitand_assign(&mut self, rhs: &Self) {
739        // Remove fragment that aren't on the RHS
740        self.inner
741            .retain(|fragment, _| rhs.inner.contains_key(fragment));
742
743        // For fragments that are on the RHS, intersect the bitmaps
744        for (fragment, mut lhs_set) in &mut self.inner {
745            match (&mut lhs_set, rhs.inner.get(fragment)) {
746                (_, None) => {} // Already handled by retain
747                (_, Some(RowAddrSelection::Full)) => {
748                    // Everything selected on RHS, so can leave LHS untouched.
749                }
750                (RowAddrSelection::Partial(lhs_set), Some(RowAddrSelection::Partial(rhs_set))) => {
751                    *lhs_set &= rhs_set;
752                }
753                (RowAddrSelection::Full, Some(RowAddrSelection::Partial(rhs_set))) => {
754                    *lhs_set = RowAddrSelection::Partial(rhs_set.clone());
755                }
756            }
757        }
758        // Some bitmaps might now be empty. If they are, we should remove them.
759        self.inner.retain(|_, set| match set {
760            RowAddrSelection::Partial(set) => !set.is_empty(),
761            RowAddrSelection::Full => true,
762        });
763    }
764}
765
766impl std::ops::Sub<Self> for RowAddrTreeMap {
767    type Output = Self;
768
769    fn sub(mut self, rhs: Self) -> Self {
770        self -= &rhs;
771        self
772    }
773}
774
775impl std::ops::Sub<&Self> for RowAddrTreeMap {
776    type Output = Self;
777
778    fn sub(mut self, rhs: &Self) -> Self {
779        self -= rhs;
780        self
781    }
782}
783
impl std::ops::SubAssign<&Self> for RowAddrTreeMap {
    /// Set difference: removes every row in `rhs` from `self`.
    ///
    /// Fragments that end up empty are removed entirely.
    fn sub_assign(&mut self, rhs: &Self) {
        for (fragment, rhs_set) in &rhs.inner {
            match self.inner.get_mut(fragment) {
                // Nothing selected in this fragment; nothing to subtract from.
                None => {}
                Some(RowAddrSelection::Full) => {
                    // The whole fragment is currently selected.
                    match rhs_set {
                        RowAddrSelection::Full => {
                            self.inner.remove(fragment);
                        }
                        RowAddrSelection::Partial(rhs_set) => {
                            // This generally won't be hit.
                            // Materialize "all u32 offsets" and subtract the rhs rows.
                            let mut set = RoaringBitmap::full();
                            set -= rhs_set;
                            self.inner.insert(*fragment, RowAddrSelection::Partial(set));
                        }
                    }
                }
                Some(RowAddrSelection::Partial(lhs_set)) => match rhs_set {
                    RowAddrSelection::Full => {
                        // rhs removes the entire fragment.
                        self.inner.remove(fragment);
                    }
                    RowAddrSelection::Partial(rhs_set) => {
                        *lhs_set -= rhs_set;
                        // Prune empty bitmaps so `is_empty` stays accurate.
                        if lhs_set.is_empty() {
                            self.inner.remove(fragment);
                        }
                    }
                },
            }
        }
    }
}
818
819impl FromIterator<u64> for RowAddrTreeMap {
820    fn from_iter<T: IntoIterator<Item = u64>>(iter: T) -> Self {
821        let mut inner = BTreeMap::new();
822        for row_addr in iter {
823            let upper = (row_addr >> 32) as u32;
824            let lower = row_addr as u32;
825            match inner.get_mut(&upper) {
826                None => {
827                    let mut set = RoaringBitmap::new();
828                    set.insert(lower);
829                    inner.insert(upper, RowAddrSelection::Partial(set));
830                }
831                Some(RowAddrSelection::Full) => {
832                    // If the fragment is already selected then there is nothing to do
833                }
834                Some(RowAddrSelection::Partial(set)) => {
835                    set.insert(lower);
836                }
837            }
838        }
839        Self { inner }
840    }
841}
842
843impl<'a> FromIterator<&'a u64> for RowAddrTreeMap {
844    fn from_iter<T: IntoIterator<Item = &'a u64>>(iter: T) -> Self {
845        Self::from_iter(iter.into_iter().copied())
846    }
847}
848
849impl From<Range<u64>> for RowAddrTreeMap {
850    fn from(range: Range<u64>) -> Self {
851        let mut map = Self::default();
852        map.insert_range(range);
853        map
854    }
855}
856
857impl From<RangeInclusive<u64>> for RowAddrTreeMap {
858    fn from(range: RangeInclusive<u64>) -> Self {
859        let mut map = Self::default();
860        map.insert_range(range);
861        map
862    }
863}
864
865impl From<RoaringTreemap> for RowAddrTreeMap {
866    fn from(roaring: RoaringTreemap) -> Self {
867        let mut inner = BTreeMap::new();
868        for (fragment, set) in roaring.bitmaps() {
869            inner.insert(fragment, RowAddrSelection::Partial(set.clone()));
870        }
871        Self { inner }
872    }
873}
874
875impl Extend<u64> for RowAddrTreeMap {
876    fn extend<T: IntoIterator<Item = u64>>(&mut self, iter: T) {
877        for row_addr in iter {
878            let upper = (row_addr >> 32) as u32;
879            let lower = row_addr as u32;
880            match self.inner.get_mut(&upper) {
881                None => {
882                    let mut set = RoaringBitmap::new();
883                    set.insert(lower);
884                    self.inner.insert(upper, RowAddrSelection::Partial(set));
885                }
886                Some(RowAddrSelection::Full) => {
887                    // If the fragment is already selected then there is nothing to do
888                }
889                Some(RowAddrSelection::Partial(set)) => {
890                    set.insert(lower);
891                }
892            }
893        }
894    }
895}
896
897impl<'a> Extend<&'a u64> for RowAddrTreeMap {
898    fn extend<T: IntoIterator<Item = &'a u64>>(&mut self, iter: T) {
899        self.extend(iter.into_iter().copied())
900    }
901}
902
903// Extending with RowAddrTreeMap is basically a cumulative set union
904impl Extend<Self> for RowAddrTreeMap {
905    fn extend<T: IntoIterator<Item = Self>>(&mut self, iter: T) {
906        for other in iter {
907            for (fragment, set) in other.inner {
908                match self.inner.get_mut(&fragment) {
909                    None => {
910                        self.inner.insert(fragment, set);
911                    }
912                    Some(RowAddrSelection::Full) => {
913                        // If the fragment is already selected then there is nothing to do
914                    }
915                    Some(RowAddrSelection::Partial(lhs_set)) => match set {
916                        RowAddrSelection::Full => {
917                            self.inner.insert(fragment, RowAddrSelection::Full);
918                        }
919                        RowAddrSelection::Partial(rhs_set) => {
920                            *lhs_set |= rhs_set;
921                        }
922                    },
923                }
924            }
925        }
926    }
927}
928
929pub fn bitmap_to_ranges(bitmap: &RoaringBitmap) -> Vec<Range<u64>> {
930    let mut ranges = Vec::new();
931    let mut iter = bitmap.iter();
932    while let Some(r) = iter.next_range() {
933        ranges.push(*r.start() as u64..(*r.end() as u64 + 1));
934    }
935    ranges
936}
937
938pub fn ranges_to_bitmap(ranges: &[Range<u64>], sorted: bool) -> RoaringBitmap {
939    if ranges.is_empty() {
940        return RoaringBitmap::new();
941    }
942    if sorted {
943        let sample_size = ranges.len().min(10);
944        let avg_len: u64 = ranges
945            .iter()
946            .take(sample_size)
947            .map(|r| r.end - r.start)
948            .sum::<u64>()
949            / sample_size as u64;
950        // from_sorted_iter appends each value in O(1) but must visit every u32.
951        // insert_range bulk-fills containers but does a binary search per call.
952        // Crossover is ~6: below that, iterating all values is cheaper.
953        if avg_len <= 6 {
954            return RoaringBitmap::from_sorted_iter(
955                ranges.iter().flat_map(|r| r.start as u32..r.end as u32),
956            )
957            .unwrap();
958        }
959    }
960    let mut bm = RoaringBitmap::new();
961    for r in ranges {
962        bm.insert_range(r.start as u32..r.end as u32);
963    }
964    bm
965}
966
/// A set of stable row ids backed by a 64-bit Roaring bitmap.
///
/// This is a thin wrapper around [`RoaringTreemap`]. It represents a
/// collection of unique row ids and provides the common row-set
/// operations defined by [`RowSetOps`].
#[derive(Clone, Debug, Default, PartialEq)]
pub struct RowIdSet {
    // The underlying 64-bit roaring bitmap holding the row ids.
    inner: RoaringTreemap,
}
976
977impl RowIdSet {
978    /// Creates an empty set of row ids.
979    pub fn new() -> Self {
980        Self::default()
981    }
982    /// Returns an iterator over the contained row ids in ascending order.
983    pub fn iter(&self) -> impl Iterator<Item = u64> + '_ {
984        self.inner.iter()
985    }
986    /// Returns the union of `self` and `other`.
987    pub fn union(mut self, other: &Self) -> Self {
988        self.inner |= &other.inner;
989        self
990    }
991    /// Returns the set difference `self \\ other`.
992    pub fn difference(mut self, other: &Self) -> Self {
993        self.inner -= &other.inner;
994        self
995    }
996}
997
998impl RowSetOps for RowIdSet {
999    type Row = u64;
1000    fn is_empty(&self) -> bool {
1001        self.inner.is_empty()
1002    }
1003    fn len(&self) -> Option<u64> {
1004        Some(self.inner.len())
1005    }
1006    fn remove(&mut self, row: Self::Row) -> bool {
1007        self.inner.remove(row)
1008    }
1009    fn contains(&self, row: Self::Row) -> bool {
1010        self.inner.contains(row)
1011    }
1012    fn union_all(other: &[&Self]) -> Self {
1013        let mut result = other
1014            .first()
1015            .map_or(Self::default(), |&first| first.clone());
1016        for set in other {
1017            result.inner |= &set.inner;
1018        }
1019        result
1020    }
1021    #[track_caller]
1022    fn from_sorted_iter<I>(iter: I) -> Result<Self>
1023    where
1024        I: IntoIterator<Item = Self::Row>,
1025    {
1026        let mut inner = RoaringTreemap::new();
1027        let mut last: Option<u64> = None;
1028        for value in iter {
1029            if let Some(prev) = last
1030                && value < prev
1031            {
1032                return Err(Error::internal(
1033                    "RowIdSet::from_sorted_iter called with non-sorted input",
1034                ));
1035            }
1036            inner.insert(value);
1037            last = Some(value);
1038        }
1039        Ok(Self { inner })
1040    }
1041}
1042
/// A mask over stable row ids based on an allow-list or block-list.
///
/// The semantics mirror [`RowAddrMask`], but operate on stable
/// row ids instead of physical row addresses.
///
/// The default value is `BlockList` of an empty set, i.e. all rows allowed.
#[derive(Clone, Debug, PartialEq)]
pub enum RowIdMask {
    /// Only the ids in the set are selected.
    AllowList(RowIdSet),
    /// All ids are selected except those in the set.
    BlockList(RowIdSet),
}
1054
1055impl Default for RowIdMask {
1056    fn default() -> Self {
1057        // Empty block list means all rows are allowed
1058        Self::BlockList(RowIdSet::default())
1059    }
1060}
1061impl RowIdMask {
1062    /// Create a mask allowing all rows, this is an alias for [`Default`].
1063    pub fn all_rows() -> Self {
1064        Self::default()
1065    }
1066    /// Create a mask that doesn't allow any row id.
1067    pub fn allow_nothing() -> Self {
1068        Self::AllowList(RowIdSet::default())
1069    }
1070    /// Create a mask from an allow list.
1071    pub fn from_allowed(allow_list: RowIdSet) -> Self {
1072        Self::AllowList(allow_list)
1073    }
1074    /// Create a mask from a block list.
1075    pub fn from_block(block_list: RowIdSet) -> Self {
1076        Self::BlockList(block_list)
1077    }
1078    /// True if the row id is selected by the mask, false otherwise.
1079    pub fn selected(&self, row_id: u64) -> bool {
1080        match self {
1081            Self::AllowList(allow_list) => allow_list.contains(row_id),
1082            Self::BlockList(block_list) => !block_list.contains(row_id),
1083        }
1084    }
1085    /// Return the indices of the input row ids that are selected by the mask.
1086    pub fn selected_indices<'a>(&self, row_ids: impl Iterator<Item = &'a u64> + 'a) -> Vec<u64> {
1087        row_ids
1088            .enumerate()
1089            .filter_map(|(idx, row_id)| {
1090                if self.selected(*row_id) {
1091                    Some(idx as u64)
1092                } else {
1093                    None
1094                }
1095            })
1096            .collect()
1097    }
1098    /// Also block the given ids.
1099    ///
1100    /// * `AllowList(a)` -> `AllowList(a \\ block_list)`
1101    /// * `BlockList(b)` -> `BlockList(b union block_list)`
1102    pub fn also_block(self, block_list: RowIdSet) -> Self {
1103        match self {
1104            Self::AllowList(allow_list) => Self::AllowList(allow_list.difference(&block_list)),
1105            Self::BlockList(existing) => Self::BlockList(existing.union(&block_list)),
1106        }
1107    }
1108    /// Also allow the given ids.
1109    ///
1110    /// * `AllowList(a)` -> `AllowList(a union allow_list)`
1111    /// * `BlockList(b)` -> `BlockList(b \\ allow_list)`
1112    pub fn also_allow(self, allow_list: RowIdSet) -> Self {
1113        match self {
1114            Self::AllowList(existing) => Self::AllowList(existing.union(&allow_list)),
1115            Self::BlockList(block_list) => Self::BlockList(block_list.difference(&allow_list)),
1116        }
1117    }
1118    /// Return the maximum number of row ids that could be selected by this mask.
1119    ///
1120    /// Will be `None` if this is a `BlockList` (unbounded).
1121    pub fn max_len(&self) -> Option<u64> {
1122        match self {
1123            Self::AllowList(selection) => selection.len(),
1124            Self::BlockList(_) => None,
1125        }
1126    }
1127    /// Iterate over the row ids that are selected by the mask.
1128    ///
1129    /// This is only possible if this is an `AllowList`. For a `BlockList`
1130    /// the domain of possible row ids is unbounded.
1131    pub fn iter_ids(&self) -> Option<Box<dyn Iterator<Item = u64> + '_>> {
1132        match self {
1133            Self::AllowList(allow_list) => Some(Box::new(allow_list.iter())),
1134            Self::BlockList(_) => None,
1135        }
1136    }
1137}
1138
1139#[cfg(test)]
1140mod tests {
1141    use super::*;
1142    use proptest::{prop_assert, prop_assert_eq};
1143
1144    fn rows(ids: &[u64]) -> RowAddrTreeMap {
1145        RowAddrTreeMap::from_iter(ids)
1146    }
1147
1148    fn assert_mask_selects(mask: &RowAddrMask, selected: &[u64], not_selected: &[u64]) {
1149        for &id in selected {
1150            assert!(mask.selected(id), "Expected row {} to be selected", id);
1151        }
1152        for &id in not_selected {
1153            assert!(!mask.selected(id), "Expected row {} to NOT be selected", id);
1154        }
1155    }
1156
1157    fn selected_in_range(mask: &RowAddrMask, range: std::ops::Range<u64>) -> Vec<u64> {
1158        range.filter(|val| mask.selected(*val)).collect()
1159    }
1160
    #[test]
    fn test_row_addr_mask_construction() {
        // Default mask (empty block list): everything selected, unbounded.
        let full_mask = RowAddrMask::all_rows();
        assert_eq!(full_mask.max_len(), None);
        assert_mask_selects(&full_mask, &[0, 1, 4 << 32 | 3], &[]);
        assert_eq!(full_mask.allow_list(), None);
        assert_eq!(full_mask.block_list(), Some(&RowAddrTreeMap::default()));
        assert!(full_mask.iter_addrs().is_none());

        // Empty allow list: nothing selected; iterator exists but is empty.
        let empty_mask = RowAddrMask::allow_nothing();
        assert_eq!(empty_mask.max_len(), Some(0));
        assert_mask_selects(&empty_mask, &[], &[0, 1, 4 << 32 | 3]);
        assert_eq!(empty_mask.allow_list(), Some(&RowAddrTreeMap::default()));
        assert_eq!(empty_mask.block_list(), None);
        let iter = empty_mask.iter_addrs();
        assert!(iter.is_some());
        assert_eq!(iter.unwrap().count(), 0);

        // Finite allow list: bounded length and enumerable addresses.
        let allow_list = RowAddrMask::from_allowed(rows(&[10, 20, 30]));
        assert_eq!(allow_list.max_len(), Some(3));
        assert_mask_selects(&allow_list, &[10, 20, 30], &[0, 15, 25, 40]);
        assert_eq!(allow_list.allow_list(), Some(&rows(&[10, 20, 30])));
        assert_eq!(allow_list.block_list(), None);
        let iter = allow_list.iter_addrs();
        assert!(iter.is_some());
        let ids: Vec<u64> = iter.unwrap().map(|addr| addr.into()).collect();
        assert_eq!(ids, vec![10, 20, 30]);

        // Allow list with a full fragment: length unknown, not enumerable.
        let mut full_frag = RowAddrTreeMap::default();
        full_frag.insert_fragment(2);
        let allow_list = RowAddrMask::from_allowed(full_frag);
        assert_eq!(allow_list.max_len(), None);
        assert_mask_selects(&allow_list, &[(2 << 32) + 5], &[(3 << 32) + 5]);
        assert!(allow_list.iter_addrs().is_none());
    }
1196
    #[test]
    fn test_selected_indices() {
        // selected_indices returns positions (not values) of selected inputs.
        // Allow list
        let mask = RowAddrMask::from_allowed(rows(&[10, 20, 40]));
        assert!(mask.selected_indices(std::iter::empty()).is_empty());
        assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[1, 3]);

        // Block list
        let mask = RowAddrMask::from_block(rows(&[10, 20, 40]));
        assert!(mask.selected_indices(std::iter::empty()).is_empty());
        assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[0, 2]);
    }
1209
    #[test]
    fn test_also_allow() {
        // Allow list: also_allow unions into the allow set.
        let mask = RowAddrMask::from_allowed(rows(&[10, 20]));
        let new_mask = mask.also_allow(rows(&[20, 30, 40]));
        assert_eq!(new_mask, RowAddrMask::from_allowed(rows(&[10, 20, 30, 40])));

        // Block list: also_allow subtracts from the block set.
        let mask = RowAddrMask::from_block(rows(&[10, 20, 30]));
        let new_mask = mask.also_allow(rows(&[20, 40]));
        assert_eq!(new_mask, RowAddrMask::from_block(rows(&[10, 30])));
    }
1222
    #[test]
    fn test_also_block() {
        // Allow list: also_block subtracts from the allow set.
        let mask = RowAddrMask::from_allowed(rows(&[10, 20, 30]));
        let new_mask = mask.also_block(rows(&[20, 40]));
        assert_eq!(new_mask, RowAddrMask::from_allowed(rows(&[10, 30])));

        // Block list: also_block unions into the block set.
        let mask = RowAddrMask::from_block(rows(&[10, 20]));
        let new_mask = mask.also_block(rows(&[20, 30, 40]));
        assert_eq!(new_mask, RowAddrMask::from_block(rows(&[10, 20, 30, 40])));
    }
1235
    #[test]
    fn test_iter_ids() {
        // Allow list: addresses are enumerable in ascending order.
        let mask = RowAddrMask::from_allowed(rows(&[10, 20, 30]));
        let expected: Vec<_> = [10, 20, 30].into_iter().map(RowAddress::from).collect();
        assert_eq!(mask.iter_addrs().unwrap().collect::<Vec<_>>(), expected);

        // Allow list with full fragment: not enumerable (size unknown).
        let mut inner = RowAddrTreeMap::default();
        inner.insert_fragment(10);
        let mask = RowAddrMask::from_allowed(inner);
        assert!(mask.iter_addrs().is_none());

        // Block list: the selected domain is unbounded, so no iterator.
        let mask = RowAddrMask::from_block(rows(&[10, 20, 30]));
        assert!(mask.iter_addrs().is_none());
    }
1253
    #[test]
    fn test_row_addr_mask_not() {
        // Negation swaps the AllowList/BlockList variants without changing the set.
        let allow_list = RowAddrMask::from_allowed(rows(&[1, 2, 3]));
        let block_list = !allow_list.clone();
        assert_eq!(block_list, RowAddrMask::from_block(rows(&[1, 2, 3])));
        // Can roundtrip by negating again
        assert_eq!(!block_list, allow_list);
    }
1262
    #[test]
    fn test_ops() {
        // Smoke test mixing also_block with & (intersection) and | (union).
        let mask = RowAddrMask::default();
        assert_mask_selects(&mask, &[1, 5], &[]);

        let block_list = mask.also_block(rows(&[0, 5, 15]));
        assert_mask_selects(&block_list, &[1], &[5]);

        let allow_list = RowAddrMask::from_allowed(rows(&[0, 2, 5]));
        assert_mask_selects(&allow_list, &[5], &[1]);

        // Intersection: selected by both operands only.
        let combined = block_list & allow_list;
        assert_mask_selects(&combined, &[2], &[0, 5]);

        let other = RowAddrMask::from_allowed(rows(&[3]));
        let combined = combined | other;
        assert_mask_selects(&combined, &[2, 3], &[0, 5]);

        // Block-list | allow-list selects anything either side allows.
        let block_list = RowAddrMask::from_block(rows(&[0]));
        let allow_list = RowAddrMask::from_allowed(rows(&[3]));

        let combined = block_list | allow_list;
        assert_mask_selects(&combined, &[1], &[]);
    }
1287
    #[test]
    fn test_logical_and() {
        // Exhaustively checks & for every allow/block operand combination,
        // verifying commutativity by evaluating both operand orders.
        let allow1 = RowAddrMask::from_allowed(rows(&[0, 1]));
        let block1 = RowAddrMask::from_block(rows(&[1, 2]));
        let allow2 = RowAddrMask::from_allowed(rows(&[1, 2, 3, 4]));
        let block2 = RowAddrMask::from_block(rows(&[3, 4]));

        fn check(lhs: &RowAddrMask, rhs: &RowAddrMask, expected: &[u64]) {
            for mask in [lhs.clone() & rhs.clone(), rhs.clone() & lhs.clone()] {
                assert_eq!(selected_in_range(&mask, 0..10), expected);
            }
        }

        // Allow & Allow
        check(&allow1, &allow1, &[0, 1]);
        check(&allow1, &allow2, &[1]);

        // Block & Block
        check(&block1, &block1, &[0, 3, 4, 5, 6, 7, 8, 9]);
        check(&block1, &block2, &[0, 5, 6, 7, 8, 9]);

        // Allow & Block
        check(&allow1, &block1, &[0]);
        check(&allow1, &block2, &[0, 1]);
        check(&allow2, &block1, &[3, 4]);
        check(&allow2, &block2, &[1, 2]);
    }
1315
    #[test]
    fn test_logical_or() {
        // Exhaustively checks | across allow, block, and mixed masks,
        // verifying commutativity by evaluating both operand orders.
        let allow1 = RowAddrMask::from_allowed(rows(&[5, 6, 7, 8, 9]));
        let block1 = RowAddrMask::from_block(rows(&[5, 6]));
        let mixed1 = allow1.clone().also_block(rows(&[5, 6]));
        let allow2 = RowAddrMask::from_allowed(rows(&[2, 3, 4, 5, 6, 7, 8]));
        let block2 = RowAddrMask::from_block(rows(&[4, 5]));
        let mixed2 = allow2.clone().also_block(rows(&[4, 5]));

        fn check(lhs: &RowAddrMask, rhs: &RowAddrMask, expected: &[u64]) {
            for mask in [lhs.clone() | rhs.clone(), rhs.clone() | lhs.clone()] {
                assert_eq!(selected_in_range(&mask, 0..10), expected);
            }
        }

        // Each mask OR-ed with itself is idempotent.
        check(&allow1, &allow1, &[5, 6, 7, 8, 9]);
        check(&block1, &block1, &[0, 1, 2, 3, 4, 7, 8, 9]);
        check(&mixed1, &mixed1, &[7, 8, 9]);
        check(&allow2, &allow2, &[2, 3, 4, 5, 6, 7, 8]);
        check(&block2, &block2, &[0, 1, 2, 3, 6, 7, 8, 9]);
        check(&mixed2, &mixed2, &[2, 3, 6, 7, 8]);

        // Cross combinations.
        check(&allow1, &block1, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        check(&allow1, &mixed1, &[5, 6, 7, 8, 9]);
        check(&allow1, &allow2, &[2, 3, 4, 5, 6, 7, 8, 9]);
        check(&allow1, &block2, &[0, 1, 2, 3, 5, 6, 7, 8, 9]);
        check(&allow1, &mixed2, &[2, 3, 5, 6, 7, 8, 9]);
        check(&block1, &mixed1, &[0, 1, 2, 3, 4, 7, 8, 9]);
        check(&block1, &allow2, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        check(&block1, &block2, &[0, 1, 2, 3, 4, 6, 7, 8, 9]);
        check(&block1, &mixed2, &[0, 1, 2, 3, 4, 6, 7, 8, 9]);
        check(&mixed1, &allow2, &[2, 3, 4, 5, 6, 7, 8, 9]);
        check(&mixed1, &block2, &[0, 1, 2, 3, 6, 7, 8, 9]);
        check(&mixed1, &mixed2, &[2, 3, 6, 7, 8, 9]);
        check(&allow2, &block2, &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
        check(&allow2, &mixed2, &[2, 3, 4, 5, 6, 7, 8]);
        check(&block2, &mixed2, &[0, 1, 2, 3, 6, 7, 8, 9]);
    }
1354
    #[test]
    fn test_deserialize_legacy_format() {
        // Test that we can deserialize the old format where both allow_list
        // and block_list could be present in the serialized form.
        //
        // The old format (before this PR) used a struct with both allow_list and block_list
        // fields. The new format uses an enum. The deserialization code should handle
        // the case where both lists are present by converting to AllowList(allow - block).

        // Create the RowIdTreeMaps and serialize them directly
        let allow = rows(&[1, 2, 3, 4, 5, 10, 15]);
        let block = rows(&[2, 4, 15]);

        // Serialize using the stable RowIdTreeMap serialization format
        let block_bytes = {
            let mut buf = Vec::with_capacity(block.serialized_size());
            block.serialize_into(&mut buf).unwrap();
            buf
        };
        let allow_bytes = {
            let mut buf = Vec::with_capacity(allow.serialized_size());
            allow.serialize_into(&mut buf).unwrap();
            buf
        };

        // Construct a binary array with both values present (simulating old format)
        // NOTE(review): element order appears to be [block, allow] — confirm
        // this matches the legacy serialization layout.
        let old_format_array =
            BinaryArray::from_opt_vec(vec![Some(&block_bytes), Some(&allow_bytes)]);

        // Deserialize - should handle this by creating AllowList(allow - block)
        let deserialized = RowAddrMask::from_arrow(&old_format_array).unwrap();

        // The expected result: AllowList([1, 2, 3, 4, 5, 10, 15] - [2, 4, 15]) = [1, 3, 5, 10]
        assert_mask_selects(&deserialized, &[1, 3, 5, 10], &[2, 4, 15]);
        assert!(
            deserialized.allow_list().is_some(),
            "Should deserialize to AllowList variant"
        );
    }
1394
    #[test]
    fn test_roundtrip_arrow() {
        // into_arrow followed by from_arrow must reproduce the mask exactly
        // for both variants.
        let row_addrs = rows(&[1, 2, 3, 100, 2000]);

        // Allow list
        let original = RowAddrMask::from_allowed(row_addrs.clone());
        let array = original.into_arrow().unwrap();
        assert_eq!(RowAddrMask::from_arrow(&array).unwrap(), original);

        // Block list
        let original = RowAddrMask::from_block(row_addrs);
        let array = original.into_arrow().unwrap();
        assert_eq!(RowAddrMask::from_arrow(&array).unwrap(), original);
    }
1409
    #[test]
    fn test_deserialize_legacy_empty_lists() {
        // Legacy [block, allow] arrays with one or both entries null.
        // Case 1: Both None (should become all_rows)
        let array = BinaryArray::from_opt_vec(vec![None, None]);
        let mask = RowAddrMask::from_arrow(&array).unwrap();
        assert_mask_selects(&mask, &[0, 100, u64::MAX], &[]);

        // Case 2: Only block list (no allow list)
        let block = rows(&[5, 10]);
        let block_bytes = {
            let mut buf = Vec::with_capacity(block.serialized_size());
            block.serialize_into(&mut buf).unwrap();
            buf
        };
        let array = BinaryArray::from_opt_vec(vec![Some(&block_bytes[..]), None]);
        let mask = RowAddrMask::from_arrow(&array).unwrap();
        assert_mask_selects(&mask, &[0, 15], &[5, 10]);

        // Case 3: Only allow list (no block list)
        let allow = rows(&[5, 10]);
        let allow_bytes = {
            let mut buf = Vec::with_capacity(allow.serialized_size());
            allow.serialize_into(&mut buf).unwrap();
            buf
        };
        let array = BinaryArray::from_opt_vec(vec![None, Some(&allow_bytes[..])]);
        let mask = RowAddrMask::from_arrow(&array).unwrap();
        assert_mask_selects(&mask, &[5, 10], &[0, 15]);
    }
1439
    #[test]
    fn test_map_insert() {
        let mut map = RowAddrTreeMap::default();

        // insert returns true only when the row was not already present.
        assert!(!map.contains(20));
        assert!(map.insert(20));
        assert!(map.contains(20));
        assert!(!map.insert(20)); // Inserting again should be no-op

        // Row 20 has fragment id 0 (upper 32 bits), so it lands in bitmap 0.
        let bitmap = map.get_fragment_bitmap(0);
        assert!(bitmap.is_some());
        let bitmap = bitmap.unwrap();
        assert_eq!(bitmap.len(), 1);

        assert!(map.get_fragment_bitmap(1).is_none());

        // A fully-selected fragment contains every row and exposes no bitmap.
        map.insert_fragment(0);
        assert!(map.contains(0));
        assert!(!map.insert(0)); // Inserting into full fragment should be no-op
        assert!(map.get_fragment_bitmap(0).is_none());
    }
1461
    #[test]
    fn test_map_insert_range() {
        // Ranges within a fragment and one straddling the u32 boundary
        // (i.e. spanning two fragments).
        let ranges = &[
            (0..10),
            (40..500),
            ((u32::MAX as u64 - 10)..(u32::MAX as u64 + 20)),
        ];

        for range in ranges {
            let mut mask = RowAddrTreeMap::default();

            // insert_range returns the count of newly inserted rows.
            let count = mask.insert_range(range.clone());
            let expected = range.end - range.start;
            assert_eq!(count, expected);

            // Re-inserting the same range adds nothing.
            let count = mask.insert_range(range.clone());
            assert_eq!(count, 0);

            // Shifting by 5 overlaps all but the last 5 values.
            let new_range = range.start + 5..range.end + 5;
            let count = mask.insert_range(new_range.clone());
            assert_eq!(count, 5);
        }

        let mut mask = RowAddrTreeMap::default();
        // Unbounded-start and inclusive ranges are accepted too.
        let count = mask.insert_range(..10);
        assert_eq!(count, 10);
        assert!(mask.contains(0));

        let count = mask.insert_range(20..=24);
        assert_eq!(count, 5);

        // Once fragment 0 is fully selected, inserts into it add nothing.
        mask.insert_fragment(0);
        let count = mask.insert_range(100..200);
        assert_eq!(count, 0);
    }
1497
    #[test]
    fn test_map_remove() {
        let mut mask = RowAddrTreeMap::default();

        // Removing an absent row reports false.
        assert!(!mask.remove(20));

        mask.insert(20);
        assert!(mask.contains(20));
        assert!(mask.remove(20));
        assert!(!mask.contains(20));

        // Removal also works for rows inserted via a range.
        mask.insert_range(10..=20);
        assert!(mask.contains(15));
        assert!(mask.remove(15));
        assert!(!mask.contains(15));

        // We don't test removing from a full fragment, because that would take
        // a lot of memory.
    }
1517
    #[test]
    fn test_map_mask() {
        // mask() filters a map in place: allow keeps the intersection,
        // block keeps the difference.
        let mask = rows(&[0, 1, 2]);
        let mask2 = rows(&[0, 2, 3]);

        let allow_list = RowAddrMask::AllowList(mask2.clone());
        let mut actual = mask.clone();
        actual.mask(&allow_list);
        assert_eq!(actual, rows(&[0, 2]));

        let block_list = RowAddrMask::BlockList(mask2);
        let mut actual = mask;
        actual.mask(&block_list);
        assert_eq!(actual, rows(&[1]));
    }
1533
    #[test]
    #[should_panic(expected = "Size of full fragment is unknown")]
    fn test_map_insert_full_fragment_row() {
        // A fully-selected fragment has no enumerable row count, so draining
        // the address iterator must panic.
        let mut mask = RowAddrTreeMap::default();
        mask.insert_fragment(0);

        unsafe {
            let _ = mask.into_addr_iter().collect::<Vec<u64>>();
        }
    }
1544
    #[test]
    fn test_map_into_addr_iter() {
        // Addresses come out in ascending (fragment, offset) order.
        let mut mask = RowAddrTreeMap::default();
        mask.insert(0);
        mask.insert(1);
        mask.insert(1 << 32 | 5);
        mask.insert(2 << 32 | 10);

        let expected = vec![0u64, 1, 1 << 32 | 5, 2 << 32 | 10];
        // SAFETY precondition (no full fragments) holds: only individual
        // rows were inserted above.
        let actual: Vec<u64> = unsafe { mask.into_addr_iter().collect() };
        assert_eq!(actual, expected);
    }
1557
    #[test]
    fn test_map_from() {
        // From<Range> is half-open: the end bound is excluded...
        let map = RowAddrTreeMap::from(10..12);
        assert!(map.contains(10));
        assert!(map.contains(11));
        assert!(!map.contains(12));
        assert!(!map.contains(3));

        // ...while From<RangeInclusive> includes it.
        let map = RowAddrTreeMap::from(10..=12);
        assert!(map.contains(10));
        assert!(map.contains(11));
        assert!(map.contains(12));
        assert!(!map.contains(3));
    }
1572
    #[test]
    fn test_map_from_roaring() {
        // Conversion preserves membership across fragment boundaries
        // (1 << 32 lives in fragment 1).
        let bitmap = RoaringTreemap::from_iter(&[0, 1, 1 << 32]);
        let map = RowAddrTreeMap::from(bitmap);
        assert!(map.contains(0) && map.contains(1) && map.contains(1 << 32));
        assert!(!map.contains(2));
    }
1580
    #[test]
    fn test_map_extend() {
        let mut map = RowAddrTreeMap::default();
        map.insert(0);
        map.insert_fragment(1);

        // Extend mixes duplicates, new rows, rows in the already-full
        // fragment 1, and a row in a brand-new fragment 3.
        let other_rows = [0, 2, 1 << 32 | 10, 3 << 32 | 5];
        map.extend(other_rows.iter().copied());

        assert!(map.contains(0));
        assert!(map.contains(2));
        // Fragment 1 stays fully selected, so untouched offsets remain present.
        assert!(map.contains(1 << 32 | 5));
        assert!(map.contains(1 << 32 | 10));
        assert!(map.contains(3 << 32 | 5));
        assert!(!map.contains(3));
    }
1597
    #[test]
    fn test_map_extend_other_maps() {
        let mut map = RowAddrTreeMap::default();
        map.insert(0);
        map.insert_fragment(1);
        map.insert(4 << 32);

        // Extending with a whole map unions partial bitmaps and upgrades
        // fragment 4 (partial here) to full (full in other_map).
        let mut other_map = rows(&[0, 2, 1 << 32 | 10, 3 << 32 | 5]);
        other_map.insert_fragment(4);
        map.extend(std::iter::once(other_map));

        for id in [
            0,
            2,
            1 << 32 | 5,
            1 << 32 | 10,
            3 << 32 | 5,
            4 << 32,
            4 << 32 | 7,
        ] {
            assert!(map.contains(id), "Expected {} to be contained", id);
        }
        assert!(!map.contains(3));
    }
1622
1623    proptest::proptest! {
        #[test]
        fn test_map_serialization_roundtrip(
            values in proptest::collection::vec(
                (0..u32::MAX, proptest::option::of(proptest::collection::vec(0..u32::MAX, 0..1000))),
                0..10
            )
        ) {
            // Build a map mixing partial bitmaps (Some rows) and full
            // fragments (None), then check serialize/deserialize is lossless.
            let mut mask = RowAddrTreeMap::default();
            for (fragment, rows) in values {
                if let Some(rows) = rows {
                    let bitmap = RoaringBitmap::from_iter(rows);
                    mask.insert_bitmap(fragment, bitmap);
                } else {
                    mask.insert_fragment(fragment);
                }
            }

            let mut data = Vec::new();
            mask.serialize_into(&mut data).unwrap();
            let deserialized = RowAddrTreeMap::deserialize_from(data.as_slice()).unwrap();
            prop_assert_eq!(mask, deserialized);
        }
1646
        #[test]
        fn test_map_intersect(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            // Checks `left & right` against a model built by hand:
            // full∩full stays full, and a row survives if it is in both
            // sides or in the other side's full fragments.
            let mut left = RowAddrTreeMap::default();
            for fragment in left_full_fragments.clone() {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowAddrTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }
            right.extend(right_rows.iter().copied());

            // Expected full fragments: those full on both sides.
            let mut expected = RowAddrTreeMap::default();
            for fragment in &left_full_fragments {
                if right_full_fragments.contains(fragment) {
                    expected.insert_fragment(*fragment);
                }
            }

            // Expected individual rows: present on one side and covered
            // (explicitly or via a full fragment) on the other.
            let left_in_right = left_rows.iter().filter(|row| {
                right_rows.contains(row)
                    || right_full_fragments.contains(&((*row >> 32) as u32))
            });
            expected.extend(left_in_right);
            let right_in_left = right_rows.iter().filter(|row| {
                left_rows.contains(row)
                    || left_full_fragments.contains(&((*row >> 32) as u32))
            });
            expected.extend(right_in_left);

            let actual = left & right;
            prop_assert_eq!(expected, actual);
        }
1687
        #[test]
        fn test_map_union(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            // Property: `left | right` equals a reference union containing
            // every full fragment and every row address from either side.
            let mut left = RowAddrTreeMap::default();
            for fragment in left_full_fragments.clone() {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowAddrTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }
            right.extend(right_rows.iter().copied());

            let mut expected = RowAddrTreeMap::default();
            for fragment in left_full_fragments {
                expected.insert_fragment(fragment);
            }
            for fragment in right_full_fragments {
                expected.insert_fragment(fragment);
            }

            let combined_rows = left_rows.iter().chain(right_rows.iter());
            expected.extend(combined_rows);

            let actual = left | right;
            // Compare per-fragment first so a failure pinpoints the offending
            // fragment key instead of just dumping both whole maps.
            for actual_key_val in &actual.inner {
                proptest::prop_assert!(expected.inner.contains_key(actual_key_val.0));
                let expected_val = expected.inner.get(actual_key_val.0).unwrap();
                prop_assert_eq!(
                    actual_key_val.1,
                    expected_val,
                    "error on key {}",
                    actual_key_val.0
                );
            }
            prop_assert_eq!(expected, actual);
        }
1731
        #[test]
        fn test_map_subassign_rows(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
            right_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            // Property: `left -= right` (right holding only individual rows)
            // is equivalent to removing each of right's rows one at a time.
            let mut left = RowAddrTreeMap::default();
            for fragment in left_full_fragments {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowAddrTreeMap::default();
            right.extend(right_rows.iter().copied());

            let mut expected = left.clone();
            for row in right_rows {
                expected.remove(row);
            }

            left -= &right;
            prop_assert_eq!(expected, left);
        }
1755
        #[test]
        fn test_map_subassign_frags(
            left_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            right_full_fragments in proptest::collection::vec(0..u32::MAX, 0..10),
            left_rows in proptest::collection::vec(0..u64::MAX, 0..1000),
        ) {
            // Property: subtracting a map of full fragments drops those
            // fragments' entries from the left map entirely.
            let mut left = RowAddrTreeMap::default();
            for fragment in left_full_fragments {
                left.insert_fragment(fragment);
            }
            left.extend(left_rows.iter().copied());

            let mut right = RowAddrTreeMap::default();
            for fragment in right_full_fragments.clone() {
                right.insert_fragment(fragment);
            }

            // Reference result: remove the fragment keys directly.
            let mut expected = left.clone();
            for fragment in right_full_fragments {
                expected.inner.remove(&fragment);
            }

            left -= &right;
            prop_assert_eq!(expected, left);
        }
1781
        #[test]
        fn test_from_sorted_iter(
            mut rows in proptest::collection::vec(0..u64::MAX, 0..1000)
        ) {
            // Property: a sorted input builds successfully and reports the
            // input length. NOTE(review): assumes proptest does not generate
            // duplicate u64s here (overwhelmingly likely over the full range).
            rows.sort();
            let num_rows = rows.len();
            let mask = RowAddrTreeMap::from_sorted_iter(rows).unwrap();
            prop_assert_eq!(mask.len(), Some(num_rows as u64));
        }
1791
1792
1793    }
1794
1795    #[test]
1796    fn test_row_addr_selection_deep_size_of() {
1797        use deepsize::DeepSizeOf;
1798
1799        // Test Full variant - should have minimal size (just the enum discriminant)
1800        let full = RowAddrSelection::Full;
1801        let full_size = full.deep_size_of();
1802        // Full variant has no heap allocations beyond the enum itself
1803        assert!(full_size < 100); // Small sanity check
1804
1805        // Test Partial variant - should include bitmap size
1806        let mut bitmap = RoaringBitmap::new();
1807        bitmap.insert_range(0..100);
1808        let partial = RowAddrSelection::Partial(bitmap.clone());
1809        let partial_size = partial.deep_size_of();
1810        // Partial variant should be larger due to bitmap
1811        assert!(partial_size >= bitmap.serialized_size());
1812    }
1813
1814    #[test]
1815    fn test_row_addr_selection_union_all_with_full() {
1816        let full = RowAddrSelection::Full;
1817        let partial = RowAddrSelection::Partial(RoaringBitmap::from_iter(&[1, 2, 3]));
1818
1819        assert!(matches!(
1820            RowAddrSelection::union_all(&[&full, &partial]),
1821            RowAddrSelection::Full
1822        ));
1823
1824        let partial2 = RowAddrSelection::Partial(RoaringBitmap::from_iter(&[4, 5, 6]));
1825        let RowAddrSelection::Partial(bitmap) = RowAddrSelection::union_all(&[&partial, &partial2])
1826        else {
1827            panic!("Expected Partial");
1828        };
1829        assert!(bitmap.contains(1) && bitmap.contains(4));
1830    }
1831
    #[test]
    fn test_insert_range_unbounded_start() {
        let mut map = RowAddrTreeMap::default();

        // Exclusive start bound: 5 itself is not inserted.
        let count = map.insert_range((std::ops::Bound::Excluded(5), std::ops::Bound::Included(10)));
        assert_eq!(count, 5); // 6, 7, 8, 9, 10
        assert!(!map.contains(5));
        assert!(map.contains(6));
        assert!(map.contains(10));

        // Half-open range: the (excluded) end value is not inserted.
        // NOTE(review): `0..5` has a bounded end, so despite the test's name
        // the truly unbounded-end case is not exercised here — confirm
        // whether a `start..` case should be added.
        let mut map2 = RowAddrTreeMap::default();
        let count = map2.insert_range(0..5);
        assert_eq!(count, 5);
        assert!(map2.contains(0));
        assert!(map2.contains(4));
        assert!(!map2.contains(5));
    }
1851
1852    #[test]
1853    fn test_remove_from_full_fragment() {
1854        let mut map = RowAddrTreeMap::default();
1855        map.insert_fragment(0);
1856
1857        // Verify it's a full fragment - get_fragment_bitmap returns None for Full
1858        for id in [0, 100, u32::MAX as u64] {
1859            assert!(map.contains(id));
1860        }
1861        assert!(map.get_fragment_bitmap(0).is_none());
1862
1863        // Remove a value from the full fragment
1864        assert!(map.remove(50));
1865
1866        // Now it should be partial (a full RoaringBitmap minus one value)
1867        assert!(map.contains(0) && !map.contains(50) && map.contains(100));
1868        assert!(map.get_fragment_bitmap(0).is_some());
1869    }
1870
1871    #[test]
1872    fn test_retain_fragments() {
1873        let mut map = RowAddrTreeMap::default();
1874        map.insert(0); // fragment 0
1875        map.insert(1 << 32 | 5); // fragment 1
1876        map.insert(2 << 32 | 10); // fragment 2
1877        map.insert_fragment(3); // fragment 3
1878
1879        map.retain_fragments([0, 2]);
1880
1881        assert!(map.contains(0) && map.contains(2 << 32 | 10));
1882        assert!(!map.contains(1 << 32 | 5) && !map.contains(3 << 32));
1883    }
1884
1885    #[test]
1886    fn test_bitor_assign_full_fragment() {
1887        // Test BitOrAssign when LHS has Full and RHS has Partial
1888        let mut map1 = RowAddrTreeMap::default();
1889        map1.insert_fragment(0);
1890        let mut map2 = RowAddrTreeMap::default();
1891        map2.insert(5);
1892
1893        map1 |= &map2;
1894        // Full | Partial = Full
1895        assert!(map1.contains(0) && map1.contains(5) && map1.contains(100));
1896
1897        // Test BitOrAssign when LHS has Partial and RHS has Full
1898        let mut map3 = RowAddrTreeMap::default();
1899        map3.insert(5);
1900        let mut map4 = RowAddrTreeMap::default();
1901        map4.insert_fragment(0);
1902
1903        map3 |= &map4;
1904        // Partial | Full = Full
1905        assert!(map3.contains(0) && map3.contains(5) && map3.contains(100));
1906    }
1907
1908    #[test]
1909    fn test_bitand_assign_full_fragments() {
1910        // Test BitAndAssign when both have Full for same fragment
1911        let mut map1 = RowAddrTreeMap::default();
1912        map1.insert_fragment(0);
1913        let mut map2 = RowAddrTreeMap::default();
1914        map2.insert_fragment(0);
1915
1916        map1 &= &map2;
1917        // Full & Full = Full
1918        assert!(map1.contains(0) && map1.contains(100));
1919
1920        // Test BitAndAssign when LHS Full, RHS Partial
1921        let mut map3 = RowAddrTreeMap::default();
1922        map3.insert_fragment(0);
1923        let mut map4 = RowAddrTreeMap::default();
1924        map4.insert(5);
1925        map4.insert(10);
1926
1927        map3 &= &map4;
1928        // Full & Partial([5,10]) = Partial([5,10])
1929        assert!(map3.contains(5) && map3.contains(10));
1930        assert!(!map3.contains(0) && !map3.contains(100));
1931
1932        // Test that empty intersection results in removal
1933        let mut map5 = RowAddrTreeMap::default();
1934        map5.insert(5);
1935        let mut map6 = RowAddrTreeMap::default();
1936        map6.insert(10);
1937
1938        map5 &= &map6;
1939        assert!(map5.is_empty());
1940    }
1941
1942    #[test]
1943    fn test_sub_assign_with_full_fragments() {
1944        // Test SubAssign when LHS is Full and RHS is Partial
1945        let mut map1 = RowAddrTreeMap::default();
1946        map1.insert_fragment(0);
1947        let mut map2 = RowAddrTreeMap::default();
1948        map2.insert(5);
1949        map2.insert(10);
1950
1951        map1 -= &map2;
1952        // Full - Partial([5,10]) = Full minus those values
1953        assert!(map1.contains(0) && map1.contains(100));
1954        assert!(!map1.contains(5) && !map1.contains(10));
1955
1956        // Test SubAssign when both are Full for same fragment
1957        let mut map3 = RowAddrTreeMap::default();
1958        map3.insert_fragment(0);
1959        let mut map4 = RowAddrTreeMap::default();
1960        map4.insert_fragment(0);
1961
1962        map3 -= &map4;
1963        // Full - Full = empty
1964        assert!(map3.is_empty());
1965
1966        // Test SubAssign when LHS is Partial and RHS is Full
1967        let mut map5 = RowAddrTreeMap::default();
1968        map5.insert(5);
1969        map5.insert(10);
1970        let mut map6 = RowAddrTreeMap::default();
1971        map6.insert_fragment(0);
1972
1973        map5 -= &map6;
1974        // Partial - Full = empty
1975        assert!(map5.is_empty());
1976    }
1977
1978    #[test]
1979    fn test_from_iterator_with_full_fragment() {
1980        // Test that inserting into a full fragment is a no-op
1981        let mut map = RowAddrTreeMap::default();
1982        map.insert_fragment(0);
1983
1984        // Extend with values that would go into fragment 0
1985        map.extend([5u64, 10, 100].iter());
1986
1987        // Should still be full fragment
1988        for id in [0, 5, 10, 100, u32::MAX as u64] {
1989            assert!(map.contains(id));
1990        }
1991    }
1992
    #[test]
    fn test_insert_range_excluded_end() {
        // Exercise an explicit Bound::Excluded end bound.
        let mut map = RowAddrTreeMap::default();
        let count = map.insert_range((std::ops::Bound::Included(5), std::ops::Bound::Excluded(10)));
        assert_eq!(count, 5); // 5, 6, 7, 8, 9
        assert!(map.contains(5));
        assert!(map.contains(9));
        assert!(!map.contains(10));
    }
2005
2006    #[test]
2007    fn test_bitand_assign_owned() {
2008        // Test BitAndAssign<Self> (owned, not reference)
2009        let mut map1 = RowAddrTreeMap::default();
2010        map1.insert(5);
2011        map1.insert(10);
2012
2013        // Using owned rhs (not reference)
2014        map1 &= rows(&[5, 15]);
2015
2016        assert!(map1.contains(5));
2017        assert!(!map1.contains(10) && !map1.contains(15));
2018    }
2019
2020    #[test]
2021    fn test_from_iter_with_full_fragment() {
2022        // When we collect into RowAddrTreeMap, it should handle duplicates
2023        let map: RowAddrTreeMap = vec![5u64, 10, 100].into_iter().collect();
2024        assert!(map.contains(5) && map.contains(10));
2025
2026        // Test that extending a map with full fragment ignores new values
2027        let mut map = RowAddrTreeMap::default();
2028        map.insert_fragment(0);
2029        for val in [5, 10, 100] {
2030            map.insert(val); // This should be no-op since fragment is full
2031        }
2032        // Still full fragment
2033        for id in [0, 5, u32::MAX as u64] {
2034            assert!(map.contains(id));
2035        }
2036    }
2037
2038    // ============================================================================
2039    // Tests for bitmap_to_ranges / ranges_to_bitmap
2040    // ============================================================================
2041
2042    #[test]
2043    fn test_bitmap_to_ranges_empty() {
2044        let bm = RoaringBitmap::new();
2045        assert!(bitmap_to_ranges(&bm).is_empty());
2046    }
2047
2048    #[test]
2049    fn test_bitmap_to_ranges_single() {
2050        let bm = RoaringBitmap::from_iter([5]);
2051        assert_eq!(bitmap_to_ranges(&bm), vec![5..6]);
2052    }
2053
2054    #[test]
2055    fn test_bitmap_to_ranges_contiguous() {
2056        let mut bm = RoaringBitmap::new();
2057        bm.insert_range(10..20);
2058        assert_eq!(bitmap_to_ranges(&bm), vec![10..20]);
2059    }
2060
2061    #[test]
2062    fn test_bitmap_to_ranges_multiple() {
2063        let mut bm = RoaringBitmap::new();
2064        bm.insert_range(0..3);
2065        bm.insert_range(10..15);
2066        bm.insert(100);
2067        assert_eq!(bitmap_to_ranges(&bm), vec![0..3, 10..15, 100..101]);
2068    }
2069
2070    #[test]
2071    fn test_ranges_to_bitmap_empty() {
2072        let bm = ranges_to_bitmap(&[], true);
2073        assert!(bm.is_empty());
2074    }
2075
2076    #[test]
2077    fn test_ranges_to_bitmap_sorted_short_ranges() {
2078        // avg len = 1, uses from_sorted_iter path
2079        let ranges = vec![0..1, 5..6, 10..11];
2080        let bm = ranges_to_bitmap(&ranges, true);
2081        assert!(bm.contains(0) && bm.contains(5) && bm.contains(10));
2082        assert_eq!(bm.len(), 3);
2083    }
2084
2085    #[test]
2086    fn test_ranges_to_bitmap_sorted_long_ranges() {
2087        // avg len = 100, uses insert_range path
2088        let ranges = vec![0..100, 200..300];
2089        let bm = ranges_to_bitmap(&ranges, true);
2090        assert_eq!(bm.len(), 200);
2091        assert!(bm.contains(0) && bm.contains(99));
2092        assert!(!bm.contains(100));
2093        assert!(bm.contains(200) && bm.contains(299));
2094    }
2095
2096    #[test]
2097    fn test_ranges_to_bitmap_unsorted() {
2098        let ranges = vec![200..300, 0..100];
2099        let bm = ranges_to_bitmap(&ranges, false);
2100        assert_eq!(bm.len(), 200);
2101        assert!(bm.contains(0) && bm.contains(250));
2102    }
2103
2104    #[test]
2105    fn test_bitmap_ranges_roundtrip() {
2106        let mut original = RoaringBitmap::new();
2107        original.insert_range(0..50);
2108        original.insert_range(100..200);
2109        original.insert(500);
2110        original.insert_range(1000..1010);
2111
2112        let ranges = bitmap_to_ranges(&original);
2113        let reconstructed = ranges_to_bitmap(&ranges, true);
2114        assert_eq!(original, reconstructed);
2115    }
2116
2117    // ============================================================================
2118    // Tests for RowIdSet
2119    // ============================================================================
2120
2121    fn row_ids(ids: &[u64]) -> RowIdSet {
2122        let mut set = RowIdSet::new();
2123        for &id in ids {
2124            set.inner.insert(id);
2125        }
2126        set
2127    }
2128
2129    #[test]
2130    fn test_row_id_set_construction() {
2131        let set = RowIdSet::new();
2132        assert!(set.is_empty());
2133        assert_eq!(set.len(), Some(0));
2134
2135        let set = row_ids(&[10, 20, 30]);
2136        assert!(!set.is_empty());
2137        assert_eq!(set.len(), Some(3));
2138        assert!(set.contains(10));
2139        assert!(set.contains(20));
2140        assert!(set.contains(30));
2141        assert!(!set.contains(15));
2142    }
2143
2144    #[test]
2145    fn test_row_id_set_remove() {
2146        let mut set = row_ids(&[10, 20, 30]);
2147
2148        assert!(!set.remove(15)); // Not present
2149        assert_eq!(set.len(), Some(3));
2150
2151        assert!(set.remove(20)); // Present
2152        assert_eq!(set.len(), Some(2));
2153        assert!(!set.contains(20));
2154        assert!(set.contains(10));
2155        assert!(set.contains(30));
2156
2157        assert!(!set.remove(20)); // Already removed
2158    }
2159
2160    #[test]
2161    fn test_row_id_set_union() {
2162        let set1 = row_ids(&[10, 20, 30]);
2163        let set2 = row_ids(&[20, 30, 40]);
2164
2165        let result = set1.union(&set2);
2166        assert_eq!(result.len(), Some(4));
2167        for id in [10, 20, 30, 40] {
2168            assert!(result.contains(id));
2169        }
2170    }
2171
2172    #[test]
2173    fn test_row_id_set_difference() {
2174        let set1 = row_ids(&[10, 20, 30, 40]);
2175        let set2 = row_ids(&[20, 40]);
2176
2177        let result = set1.difference(&set2);
2178        assert_eq!(result.len(), Some(2));
2179        assert!(result.contains(10));
2180        assert!(result.contains(30));
2181        assert!(!result.contains(20));
2182        assert!(!result.contains(40));
2183    }
2184
2185    #[test]
2186    fn test_row_id_set_union_all() {
2187        let set1 = row_ids(&[10, 20]);
2188        let set2 = row_ids(&[20, 30]);
2189        let set3 = row_ids(&[30, 40]);
2190
2191        let result = RowIdSet::union_all(&[&set1, &set2, &set3]);
2192        assert_eq!(result.len(), Some(4));
2193        for id in [10, 20, 30, 40] {
2194            assert!(result.contains(id));
2195        }
2196
2197        // Empty slice should return empty set
2198        let result = RowIdSet::union_all(&[]);
2199        assert!(result.is_empty());
2200    }
2201
2202    #[test]
2203    fn test_row_id_set_iter() {
2204        let set = row_ids(&[10, 20, 30]);
2205        let collected: Vec<u64> = set.iter().collect();
2206        assert_eq!(collected, vec![10, 20, 30]);
2207
2208        let empty = RowIdSet::new();
2209        assert_eq!(empty.iter().count(), 0);
2210    }
2211
2212    #[test]
2213    fn test_row_id_set_from_sorted_iter() {
2214        // Valid sorted input
2215        let set = RowIdSet::from_sorted_iter([10, 20, 30, 40]).unwrap();
2216        assert_eq!(set.len(), Some(4));
2217        for id in [10, 20, 30, 40] {
2218            assert!(set.contains(id));
2219        }
2220
2221        // Empty iterator
2222        let set = RowIdSet::from_sorted_iter(std::iter::empty()).unwrap();
2223        assert!(set.is_empty());
2224
2225        // Single element
2226        let set = RowIdSet::from_sorted_iter([42]).unwrap();
2227        assert_eq!(set.len(), Some(1));
2228        assert!(set.contains(42));
2229    }
2230
2231    #[test]
2232    fn test_row_id_set_from_sorted_iter_unsorted() {
2233        // Non-sorted input should return error
2234        let result = RowIdSet::from_sorted_iter([30, 10, 20]);
2235        assert!(result.is_err());
2236        assert!(result.unwrap_err().to_string().contains("non-sorted"));
2237    }
2238
2239    #[test]
2240    fn test_row_id_set_large_values() {
2241        // Test with large u64 values
2242        let large_ids = [u64::MAX - 10, u64::MAX - 5, u64::MAX - 1];
2243        let set = row_ids(&large_ids);
2244
2245        for &id in &large_ids {
2246            assert!(set.contains(id));
2247        }
2248        assert!(!set.contains(u64::MAX));
2249        assert_eq!(set.len(), Some(3));
2250    }
2251
2252    // ============================================================================
2253    // Tests for RowIdMask
2254    // ============================================================================
2255
2256    fn assert_row_id_mask_selects(mask: &RowIdMask, selected: &[u64], not_selected: &[u64]) {
2257        for &id in selected {
2258            assert!(mask.selected(id), "Expected row id {} to be selected", id);
2259        }
2260        for &id in not_selected {
2261            assert!(
2262                !mask.selected(id),
2263                "Expected row id {} to NOT be selected",
2264                id
2265            );
2266        }
2267    }
2268
2269    #[test]
2270    fn test_row_id_mask_construction() {
2271        let full_mask = RowIdMask::all_rows();
2272        assert_eq!(full_mask.max_len(), None);
2273        assert_row_id_mask_selects(&full_mask, &[0, 1, 100, u64::MAX - 1], &[]);
2274
2275        let empty_mask = RowIdMask::allow_nothing();
2276        assert_eq!(empty_mask.max_len(), Some(0));
2277        assert_row_id_mask_selects(&empty_mask, &[], &[0, 1, 100]);
2278
2279        let allow_list = RowIdMask::from_allowed(row_ids(&[10, 20, 30]));
2280        assert_eq!(allow_list.max_len(), Some(3));
2281        assert_row_id_mask_selects(&allow_list, &[10, 20, 30], &[0, 15, 25, 40]);
2282
2283        let block_list = RowIdMask::from_block(row_ids(&[10, 20, 30]));
2284        assert_eq!(block_list.max_len(), None);
2285        assert_row_id_mask_selects(&block_list, &[0, 15, 25, 40], &[10, 20, 30]);
2286    }
2287
2288    #[test]
2289    fn test_row_id_mask_selected_indices() {
2290        // Allow list
2291        let mask = RowIdMask::from_allowed(row_ids(&[10, 20, 40]));
2292        assert!(mask.selected_indices(std::iter::empty()).is_empty());
2293        assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[1, 3]);
2294
2295        // Block list
2296        let mask = RowIdMask::from_block(row_ids(&[10, 20, 40]));
2297        assert!(mask.selected_indices(std::iter::empty()).is_empty());
2298        assert_eq!(mask.selected_indices([25, 20, 14, 10].iter()), &[0, 2]);
2299    }
2300
2301    #[test]
2302    fn test_row_id_mask_also_allow() {
2303        // Allow list
2304        let mask = RowIdMask::from_allowed(row_ids(&[10, 20]));
2305        let new_mask = mask.also_allow(row_ids(&[20, 30, 40]));
2306        assert_eq!(
2307            new_mask,
2308            RowIdMask::from_allowed(row_ids(&[10, 20, 30, 40]))
2309        );
2310
2311        // Block list
2312        let mask = RowIdMask::from_block(row_ids(&[10, 20, 30]));
2313        let new_mask = mask.also_allow(row_ids(&[20, 40]));
2314        assert_eq!(new_mask, RowIdMask::from_block(row_ids(&[10, 30])));
2315    }
2316
2317    #[test]
2318    fn test_row_id_mask_also_block() {
2319        // Allow list
2320        let mask = RowIdMask::from_allowed(row_ids(&[10, 20, 30]));
2321        let new_mask = mask.also_block(row_ids(&[20, 40]));
2322        assert_eq!(new_mask, RowIdMask::from_allowed(row_ids(&[10, 30])));
2323
2324        // Block list
2325        let mask = RowIdMask::from_block(row_ids(&[10, 20]));
2326        let new_mask = mask.also_block(row_ids(&[20, 30, 40]));
2327        assert_eq!(new_mask, RowIdMask::from_block(row_ids(&[10, 20, 30, 40])));
2328    }
2329
2330    #[test]
2331    fn test_row_id_mask_iter_ids() {
2332        // Allow list
2333        let mask = RowIdMask::from_allowed(row_ids(&[10, 20, 30]));
2334        let ids: Vec<u64> = mask.iter_ids().unwrap().collect();
2335        assert_eq!(ids, vec![10, 20, 30]);
2336
2337        // Empty allow list
2338        let mask = RowIdMask::allow_nothing();
2339        let iter = mask.iter_ids();
2340        assert!(iter.is_some());
2341        assert_eq!(iter.unwrap().count(), 0);
2342
2343        // Block list
2344        let mask = RowIdMask::from_block(row_ids(&[10, 20, 30]));
2345        assert!(mask.iter_ids().is_none());
2346    }
2347
2348    #[test]
2349    fn test_row_id_mask_default() {
2350        let mask = RowIdMask::default();
2351        // Default should be BlockList with empty set (all rows allowed)
2352        assert_row_id_mask_selects(&mask, &[0, 1, 100, 1000], &[]);
2353        assert_eq!(mask.max_len(), None);
2354    }
2355
2356    #[test]
2357    fn test_row_id_mask_ops() {
2358        let mask = RowIdMask::default();
2359        assert_row_id_mask_selects(&mask, &[1, 5, 100], &[]);
2360
2361        let block_list = mask.also_block(row_ids(&[0, 5, 15]));
2362        assert_row_id_mask_selects(&block_list, &[1, 100], &[5]);
2363
2364        let allow_list = RowIdMask::from_allowed(row_ids(&[0, 2, 5]));
2365        assert_row_id_mask_selects(&allow_list, &[5], &[1, 100]);
2366    }
2367
2368    #[test]
2369    fn test_row_id_mask_combined_ops() {
2370        // Test combining allow and block operations
2371        let mask = RowIdMask::from_allowed(row_ids(&[10, 20, 30, 40, 50]));
2372        let mask = mask.also_block(row_ids(&[20, 40]));
2373        assert_row_id_mask_selects(&mask, &[10, 30, 50], &[20, 40]);
2374
2375        let mask = mask.also_allow(row_ids(&[20, 60]));
2376        assert_row_id_mask_selects(&mask, &[10, 20, 30, 50, 60], &[40]);
2377    }
2378
2379    #[test]
2380    fn test_row_id_mask_with_large_values() {
2381        let large_ids = [u64::MAX - 10, u64::MAX - 5, u64::MAX - 1];
2382
2383        // Allow list with large values
2384        let mask = RowIdMask::from_allowed(row_ids(&large_ids));
2385        for &id in &large_ids {
2386            assert!(mask.selected(id));
2387        }
2388        assert!(!mask.selected(u64::MAX));
2389        assert!(!mask.selected(0));
2390
2391        // Block list with large values
2392        let mask = RowIdMask::from_block(row_ids(&large_ids));
2393        for &id in &large_ids {
2394            assert!(!mask.selected(id));
2395        }
2396        assert!(mask.selected(u64::MAX));
2397        assert!(mask.selected(0));
2398    }
2399
    proptest::proptest! {
        #[test]
        fn test_row_id_set_from_sorted_iter_proptest(
            mut row_ids in proptest::collection::vec(0..u64::MAX, 0..1000)
        ) {
            // Property: after sort + dedup, from_sorted_iter succeeds and
            // the resulting set contains exactly the input ids.
            row_ids.sort();
            row_ids.dedup();
            let num_rows = row_ids.len();
            let set = RowIdSet::from_sorted_iter(row_ids.clone()).unwrap();
            prop_assert_eq!(set.len(), Some(num_rows as u64));
            for id in row_ids {
                prop_assert!(set.contains(id));
            }
        }

        #[test]
        fn test_row_id_set_union_proptest(
            ids1 in proptest::collection::vec(0..u64::MAX, 0..500),
            ids2 in proptest::collection::vec(0..u64::MAX, 0..500),
        ) {
            // Property: the union contains every id from either input, and
            // its size matches a reference HashSet union.
            let set1 = row_ids(&ids1);
            let set2 = row_ids(&ids2);

            let result = set1.union(&set2);

            // All ids from both sets should be in result
            for id in ids1.iter().chain(ids2.iter()) {
                prop_assert!(result.contains(*id));
            }

            // Result size should be union size
            let expected_size = ids1.iter().chain(ids2.iter()).collect::<std::collections::HashSet<_>>().len();
            prop_assert_eq!(result.len(), Some(expected_size as u64));
        }

        #[test]
        fn test_row_id_set_difference_proptest(
            ids1 in proptest::collection::vec(0..u64::MAX, 0..500),
            ids2 in proptest::collection::vec(0..u64::MAX, 0..500),
        ) {
            // Property: an id is in the difference iff it is in ids1 and
            // not in ids2.
            let set1 = row_ids(&ids1);
            let set2 = row_ids(&ids2);

            let result = set1.difference(&set2);

            // Items in ids1 but not in ids2 should be in result
            for id in &ids1 {
                if !ids2.contains(id) {
                    prop_assert!(result.contains(*id));
                } else {
                    prop_assert!(!result.contains(*id));
                }
            }
        }

        #[test]
        fn test_row_id_mask_allow_block_proptest(
            allow_ids in proptest::collection::vec(0..10000u64, 0..100),
            block_ids in proptest::collection::vec(0..10000u64, 0..100),
            test_ids in proptest::collection::vec(0..10000u64, 0..50),
        ) {
            // Property: an id is selected iff it is allowed and not blocked.
            let mask = RowIdMask::from_allowed(row_ids(&allow_ids))
                .also_block(row_ids(&block_ids));

            for id in test_ids {
                let expected = allow_ids.contains(&id) && !block_ids.contains(&id);
                prop_assert_eq!(mask.selected(id), expected);
            }
        }
    }
2470}