polars_arrow/bitmap/bitmask.rs

#[cfg(feature = "simd")]
use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};

use polars_utils::slice::load_padded_le_u64;

use super::iterator::FastU56BitmapIter;
use super::utils::{BitmapIter, count_zeros, fmt};
use crate::bitmap::Bitmap;

/// Returns the index of the nth set bit in w, if at least n+1 bits are set.
/// The indexing is zero-based: `nth_set_bit_u32(w, 0)` returns the index of
/// the least significant set bit in w.
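///
/// For example, `nth_set_bit_u32(0b0100_0100, 0) == Some(2)` and
/// `nth_set_bit_u32(0b0100_0100, 1) == Some(6)`.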
#[inline]
pub fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {
    // If we have BMI2's PDEP available, we use it. It takes the low-order
    // bits of the first argument and spreads them along its second argument
    // where those bits are 1. So PDEP(abcdefgh, 11001001) becomes ef00g00h.
    // We use this by setting the first argument to 1 << n, which means the
    // n zero bits below its one bit will spread to the first n one bits of w,
    // after which the one bit gets deposited exactly at the nth set bit of w.
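    //
    // Worked example: for w = 0b0110_0101 (set bits at indices 0, 2, 5, 6)
    // and n = 2, PDEP(0b100, w) deposits the one bit at index 5, giving
    // 0b0010_0000, whose trailing_zeros() is 5: the index of set bit 2.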
    #[cfg(all(not(miri), target_feature = "bmi2"))]
    {
        if n >= 32 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros())
    }

    #[cfg(any(miri, not(target_feature = "bmi2")))]
    {
        // Each block of 2/4/8/16 bits contains how many set bits there are in that block.
        let set_per_2 = w - ((w >> 1) & 0x55555555);
        let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;
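        // At this point set_per_32 equals w.count_ones(). For example, with
        // w = 0b1110 the first step gives set_per_2 = 0b1001 (one set bit in
        // the low 2-bit block, two in the high one), then set_per_4 = 0b0011.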
        if n >= set_per_32 {
            return None;
        }

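        // Binary-search down the block-count tree: at each level, if the
        // lower half-block doesn't contain enough set bits, skip past it.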
        let mut idx = 0;
        let mut n = n;
        let next16 = set_per_16 & 0xff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}

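/// A read-only view of a range of bits: `len` bits of `bytes`, starting at
/// bit offset `offset`.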
#[derive(Default, Clone)]
pub struct BitMask<'a> {
    bytes: &'a [u8],
    offset: usize,
    len: usize,
}

impl std::fmt::Debug for BitMask<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { bytes, offset, len } = self;
        let offset_num_bytes = offset / 8;
        let offset_in_byte = offset % 8;
        fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)
    }
}

impl<'a> BitMask<'a> {
    pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
        let (bytes, offset, len) = bitmap.as_slice();
        Self::new(bytes, offset, len)
    }

    pub fn inner(&self) -> (&[u8], usize, usize) {
        (self.bytes, self.offset, self.len)
    }

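    /// Creates a new [`BitMask`] over `len` bits of `bytes`, starting at bit
    /// offset `offset`.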
    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
        // Check the length so we can use unchecked accesses in our getters.
        assert!(bytes.len() * 8 >= len + offset);
        Self { bytes, offset, len }
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn advance_by(&mut self, idx: usize) {
        assert!(idx <= self.len);
        self.offset += idx;
        self.len -= idx;
    }

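    /// Splits the mask in two at bit index `idx`: the first half holds the
    /// first `idx` bits, the second the remaining `len - idx` bits.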
    #[inline]
    pub fn split_at(&self, idx: usize) -> (Self, Self) {
        assert!(idx <= self.len);
        unsafe { self.split_at_unchecked(idx) }
    }

    /// # Safety
    /// The index must be in-bounds.
    #[inline]
    pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
        debug_assert!(idx <= self.len);
        let left = Self { len: idx, ..*self };
        let right = Self {
            len: self.len - idx,
            offset: self.offset + idx,
            ..*self
        };
        (left, right)
    }

    #[inline]
    pub fn sliced(&self, offset: usize, length: usize) -> Self {
        assert!(offset.checked_add(length).unwrap() <= self.len);
        unsafe { self.sliced_unchecked(offset, length) }
    }

    /// # Safety
    /// `offset + length` must be in-bounds.
    #[inline]
    pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {
        if cfg!(debug_assertions) {
            assert!(offset.checked_add(length).unwrap() <= self.len);
        }

        Self {
            bytes: self.bytes,
            offset: self.offset + offset,
            len: length,
        }
    }

    pub fn unset_bits(&self) -> usize {
        count_zeros(self.bytes, self.offset, self.len)
    }

    pub fn set_bits(&self) -> usize {
        self.len - self.unset_bits()
    }

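    /// Returns a [`FastU56BitmapIter`] over the bits of this mask, yielding
    /// 56 bits at a time packed in the low bits of a u64.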
    pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> {
        FastU56BitmapIter::new(self.bytes, self.offset, self.len)
    }

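    /// Loads the `N` bits starting at bit `idx` as a SIMD mask. Bits past the
    /// end of this mask read as unset.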
    #[cfg(feature = "simd")]
    #[inline]
    pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>
    where
        T: MaskElement,
        LaneCount<N>: SupportedLaneCount,
    {
        // We don't support 64-lane masks because then we couldn't load our
        // bitwise mask as a u64 and then do the byteshift on it.

        let lanes = LaneCount::<N>::BITMASK_LEN;
        assert!(lanes < 64);

        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + lanes <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            Mask::from_bitmask(mask >> byte_shift)
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking:
            // zero out all bits at or past self.len, as in get_u32. Shifting
            // off only the top bits of the u64 would let garbage bits between
            // self.len and the end of the load leak into the upper lanes.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let num_in_bounds = self.len - idx;
            let out_of_bounds_mask = (1u64 << num_in_bounds) - 1;
            Mask::from_bitmask((mask >> byte_shift) & out_of_bounds_mask)
        } else {
            Mask::from_bitmask(0u64)
        }
    }

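    /// Loads the 32 bits starting at bit `idx` as a u32. Bits past the end of
    /// this mask read as zero.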
    #[inline]
    pub fn get_u32(&self, idx: usize) -> u32 {
        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + 32 <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            (mask >> byte_shift) as u32
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;
            ((mask >> byte_shift) as u32) & out_of_bounds_mask
        } else {
            0
        }
    }

    /// Computes the index of the nth set bit at or after start.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx(0, 0)` finds the index of
    /// the first set bit (which can be 0 itself). The returned index is
    /// absolute, not relative to start.
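    ///
    /// For example, `nth_set_bit_idx(1, 4)` returns the absolute index of the
    /// second set bit at position 4 or later, if one exists.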
    pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {
        while start < self.len {
            let next_u32_mask = self.get_u32(start);
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(start + n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let idx = unsafe {
                        // SAFETY: we know the nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize
                    };
                    return Some(start + idx);
                }
                n -= ones;
            }

            start += 32;
        }

        None
    }

    /// Computes the index of the nth set bit before end, counting backwards.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx_rev(0, len)` finds the index
    /// of the last set bit (which can be 0 as well). The returned index is
    /// absolute (counted from the beginning), not relative to end.
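    ///
    /// For example, `nth_set_bit_idx_rev(0, self.len())` returns the index of
    /// the last set bit in the mask, if any.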
    pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
        while end > 0 {
            // We want to find bits *before* end, so if end < 32 we must mask
            // out the bits at or past position end.
            let (u32_mask_start, u32_mask_mask) = if end >= 32 {
                (end - 32, u32::MAX)
            } else {
                (0, (1 << end) - 1)
            };
            let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(end - 1 - n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let rev_n = ones - 1 - n;
                    let idx = unsafe {
                        // SAFETY: we know the rev_nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize
                    };
                    return Some(u32_mask_start + idx);
                }
                n -= ones;
            }

            end = u32_mask_start;
        }

        None
    }

    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        let byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;

        if idx < self.len {
            // SAFETY: we know this is in-bounds.
            let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
            (byte >> byte_shift) & 1 == 1
        } else {
            false
        }
    }

    pub fn iter(&self) -> BitmapIter<'_> {
        BitmapIter::new(self.bytes, self.offset, self.len)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    fn naive_nth_bit_set(w: u32, mut n: u32) -> Option<u32> {
        for i in 0..32 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
            }
        }
        None
    }

    #[test]
    fn test_nth_set_bit_u32() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u32(0, n), None);
        }

        for i in 0..32 {
            assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u32(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;
            for n in 0..=32 {
                assert_eq!(nth_set_bit_u32(rnd, n), naive_nth_bit_set(rnd, n));
            }
        }
    }
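
    // A small extra check of get_u32's partially-out-of-bounds handling
    // (illustrative test added here; expected values worked out by hand).
    #[test]
    fn test_get_u32_out_of_bounds() {
        // Twelve valid bits, but all 16 bits of the backing bytes are set;
        // the bits at or past len must not leak into the result.
        let mask = BitMask::new(&[0xff, 0xff], 0, 12);
        assert_eq!(mask.get_u32(0), 0b1111_1111_1111);
        assert_eq!(mask.get_u32(4), 0b1111_1111);
        assert_eq!(mask.get_u32(12), 0);
    }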
}