polars_arrow/bitmap/
bitmask.rs

#[cfg(feature = "simd")]
use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};

use polars_utils::slice::load_padded_le_u64;

use super::iterator::FastU56BitmapIter;
use super::utils::{self, BitChunk, BitChunks, BitmapIter, count_zeros, fmt};
use crate::bitmap::Bitmap;

/// Returns the index of the nth set bit in `w`, if at least n+1 bits are set.
/// Indexing is zero-based: `nth_set_bit_u32(w, 0)` returns the index of the
/// least significant set bit in `w`.
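/// For example, `nth_set_bit_u32(0b10110, 2) == Some(4)`, while
/// `nth_set_bit_u32(0b10110, 3) == None` since only three bits are set.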
#[inline]
pub fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {
    // If we have BMI2's PDEP available, we use it. It takes the low-order
    // bits of its first argument and deposits them along its second argument
    // wherever that has one bits. So PDEP(abcdefgh, 11001001) becomes ef00g00h.
    // We use this by setting the first argument to 1 << n, so its first n
    // zero bits spread over the first n one bits of w, after which the one
    // bit gets copied exactly to the nth one bit of w.
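    // Concretely, for w = 0b10110010 and n = 2, we deposit 1 << 2 = 0b100
    // into w: the two low zero bits land on the two lowest set bits of w
    // (positions 1 and 4), and the one bit lands on the third set bit
    // (position 5), so PDEP returns 0b00100000 and trailing_zeros gives 5.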
    #[cfg(all(not(miri), target_feature = "bmi2"))]
    {
        if n >= 32 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros())
    }

    #[cfg(any(miri, not(target_feature = "bmi2")))]
    {
        // Each block of 2/4/8/16 bits contains how many set bits there are in that block.
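        // For example, for w = 0b1101: set_per_2 = 0b1001 (one set bit in the
        // low pair, two in the high pair) and set_per_4 = 0b0011 (three set
        // bits in the low nibble). The descent below walks from the widest
        // blocks to the narrowest, subtracting the count of each block it
        // skips over.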
        let set_per_2 = w - ((w >> 1) & 0x55555555);
        let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;

        if n >= set_per_32 {
            return None;
        }

        let mut idx = 0;
        let mut n = n;

        let next16 = set_per_16 & 0xff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}

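/// Returns the index of the nth set bit in `w`, if at least n+1 bits are set;
/// the 64-bit counterpart of [`nth_set_bit_u32`].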
#[inline]
pub fn nth_set_bit_u64(w: u64, n: u64) -> Option<u64> {
    #[cfg(all(not(miri), target_feature = "bmi2"))]
    {
        if n >= 64 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u64(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros().into())
    }

    #[cfg(any(miri, not(target_feature = "bmi2")))]
    {
        // Each block of 2/4/8/16/32 bits contains how many set bits there are in that block.
        let set_per_2 = w - ((w >> 1) & 0x5555555555555555);
        let set_per_4 = (set_per_2 & 0x3333333333333333) + ((set_per_2 >> 2) & 0x3333333333333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0x0000ffff0000ffff;
        let set_per_64 = (set_per_32 + (set_per_32 >> 32)) & 0xffffffff;

        if n >= set_per_64 {
            return None;
        }

        let mut idx = 0;
        let mut n = n;

        let next32 = set_per_32 & 0xffff;
        if n >= next32 {
            n -= next32;
            idx += 32;
        }
        let next16 = (set_per_16 >> idx) & 0xffff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}

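/// A read-only view of `len` bits of `bytes`, starting at bit `offset`.
///
/// Bits are addressed LSB-first within each byte. Slicing and splitting only
/// adjust `offset` and `len`; the underlying bytes are never copied.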
#[derive(Default, Clone, Copy)]
pub struct BitMask<'a> {
    bytes: &'a [u8],
    offset: usize,
    len: usize,
}

impl std::fmt::Debug for BitMask<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { bytes, offset, len } = self;
        let offset_num_bytes = offset / 8;
        let offset_in_byte = offset % 8;
        fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)
    }
}

impl<'a> BitMask<'a> {
    pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
        let (bytes, offset, len) = bitmap.as_slice();
        Self::new(bytes, offset, len)
    }

    pub fn inner(&self) -> (&[u8], usize, usize) {
        (self.bytes, self.offset, self.len)
    }

    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
        // Check the length here so the unchecked accesses in our getters are sound.
        assert!(bytes.len() * 8 >= len + offset);
        Self { bytes, offset, len }
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn advance_by(&mut self, idx: usize) {
        assert!(idx <= self.len);
        self.offset += idx;
        self.len -= idx;
    }

    #[inline]
    pub fn split_at(&self, idx: usize) -> (Self, Self) {
        assert!(idx <= self.len);
        unsafe { self.split_at_unchecked(idx) }
    }

    /// # Safety
    /// `idx` must not exceed `self.len`.
    #[inline]
    pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
        debug_assert!(idx <= self.len);
        let left = Self { len: idx, ..*self };
        let right = Self {
            len: self.len - idx,
            offset: self.offset + idx,
            ..*self
        };
        (left, right)
    }

    #[inline]
    pub fn sliced(&self, offset: usize, length: usize) -> Self {
        assert!(offset.checked_add(length).unwrap() <= self.len);
        unsafe { self.sliced_unchecked(offset, length) }
    }

    /// # Safety
    /// `offset + length` must not exceed `self.len`.
    #[inline]
    pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {
        if cfg!(debug_assertions) {
            assert!(offset.checked_add(length).unwrap() <= self.len);
        }

        Self {
            bytes: self.bytes,
            offset: self.offset + offset,
            len: length,
        }
    }

    pub fn unset_bits(&self) -> usize {
        count_zeros(self.bytes, self.offset, self.len)
    }

    pub fn set_bits(&self) -> usize {
        self.len - self.unset_bits()
    }

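    /// Returns an iterator over the bits of this mask that yields them 56 bits
    /// (7 bytes) at a time.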
    pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> {
        FastU56BitmapIter::new(self.bytes, self.offset, self.len)
    }

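    /// Loads the `N` bits starting at bit `idx` into a SIMD mask; bits at or
    /// past `self.len` read as false.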
    #[cfg(feature = "simd")]
    #[inline]
    pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>
    where
        T: MaskElement,
        LaneCount<N>: SupportedLaneCount,
    {
        // We don't support 64-lane masks because then we couldn't load our
        // bitwise mask as a u64 and then do the byteshift on it.

        let lanes = LaneCount::<N>::BITMASK_LEN;
        assert!(lanes < 64);

        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + lanes <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            Mask::from_bitmask(mask >> byte_shift)
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let num_out_of_bounds = idx + lanes - self.len;
            let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
            Mask::from_bitmask(shifted)
        } else {
            Mask::from_bitmask(0u64)
        }
    }

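    /// Loads the 32 bits starting at bit `idx` as a `u32`; bits at or past
    /// `self.len` read as zero.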
    #[inline]
    pub fn get_u32(&self, idx: usize) -> u32 {
        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + 32 <= self.len {
            // SAFETY: fast path, we know this is completely in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            (mask >> byte_shift) as u32
        } else if idx < self.len {
            // SAFETY: we know that at least the first byte is in-bounds.
            // This is partially out of bounds, we have to do extra masking.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;
            ((mask >> byte_shift) as u32) & out_of_bounds_mask
        } else {
            0
        }
    }

    /// Computes the index of the nth set bit at or after `start`.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx(0, 0)` finds the index of the
    /// first set bit (which can be 0 as well). The returned index is absolute,
    /// not relative to `start`.
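    ///
    /// For example, for a mask holding the bits `0, 1, 1` (LSB-first),
    /// `nth_set_bit_idx(1, 0) == Some(2)`.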
    pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {
        while start < self.len {
            let next_u32_mask = self.get_u32(start);
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(start + n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let idx = unsafe {
                        // SAFETY: we know the nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize
                    };
                    return Some(start + idx);
                }
                n -= ones;
            }

            start += 32;
        }

        None
    }

    /// Computes the index of the nth set bit before `end`, counting backwards.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx_rev(0, len)` finds the index
    /// of the last set bit (which can be 0 as well). The returned index is
    /// absolute (counted from the beginning), not relative to `end`.
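    ///
    /// For example, for a mask holding the bits `0, 1, 1` (LSB-first),
    /// `nth_set_bit_idx_rev(0, 3) == Some(2)` and `nth_set_bit_idx_rev(1, 3) == Some(1)`.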
    pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
        while end > 0 {
            // We want to find bits *before* end, so if end < 32 we must mask
            // out the bits at positions end and beyond.
            let (u32_mask_start, u32_mask_mask) = if end >= 32 {
                (end - 32, u32::MAX)
            } else {
                (0, (1 << end) - 1)
            };
            let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense non-null section.
                if n < 32 {
                    return Some(end - 1 - n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let rev_n = ones - 1 - n;
                    let idx = unsafe {
                        // SAFETY: we know the rev_nth bit is in the mask.
                        nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize
                    };
                    return Some(u32_mask_start + idx);
                }
                n -= ones;
            }

            end = u32_mask_start;
        }

        None
    }

    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        if idx < self.len {
            // SAFETY: we know this is in-bounds.
            unsafe { self.get_bit_unchecked(idx) }
        } else {
            false
        }
    }

    /// Get the bit at a certain `idx`.
    ///
    /// # Safety
    ///
    /// `idx` should be smaller than `len`.
    #[inline]
    pub unsafe fn get_bit_unchecked(&self, idx: usize) -> bool {
        let byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;

        // SAFETY: we know this is in-bounds.
        let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
        (byte >> byte_shift) & 1 == 1
    }

    pub fn iter(self) -> BitmapIter<'a> {
        BitmapIter::new(self.bytes, self.offset, self.len)
    }

    /// Returns the number of zero bits from the start before a one bit is seen.
    pub fn leading_zeros(self) -> usize {
        utils::leading_zeros(self.bytes, self.offset, self.len)
    }

    /// Returns the number of one bits from the start before a zero bit is seen.
    pub fn leading_ones(self) -> usize {
        utils::leading_ones(self.bytes, self.offset, self.len)
    }

    /// Returns the number of zero bits from the back before a one bit is seen.
    pub fn trailing_zeros(self) -> usize {
        utils::trailing_zeros(self.bytes, self.offset, self.len)
    }

    /// Returns the number of one bits from the back before a zero bit is seen.
    pub fn trailing_ones(self) -> usize {
        utils::trailing_ones(self.bytes, self.offset, self.len)
    }

    /// Checks whether two [`Bitmap`]s have shared set bits.
    ///
    /// This is an optimized version of `(self & other) != 0000..`.
    pub fn intersects_with(self, other: Self) -> bool {
        self.num_intersections_with(other) != 0
    }

    /// Calculates the number of shared set bits between two [`Bitmap`]s.
    pub fn num_intersections_with(self, other: Self) -> usize {
        super::num_intersections_with(self, other)
    }

    /// Returns an iterator over the bits in [`BitChunk`]-sized chunks.
    ///
    /// This iterator is useful for operating on many bits at a time, e.g. with
    /// bitwise operations.
    pub fn chunks<T: BitChunk>(self) -> BitChunks<'a, T> {
        BitChunks::new(self.bytes, self.offset, self.len)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    fn naive_nth_bit_set_u32(mut w: u32, mut n: u32) -> Option<u32> {
        for i in 0..32 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    fn naive_nth_bit_set_u64(mut w: u64, mut n: u64) -> Option<u64> {
        for i in 0..64 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    #[test]
    fn test_nth_set_bit_u32() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u32(0, n), None);
        }

        for i in 0..32 {
            assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u32(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;
            for i in 0..=32 {
                assert_eq!(nth_set_bit_u32(rnd, i), naive_nth_bit_set_u32(rnd, i));
            }
        }
    }

    #[test]
    fn test_nth_set_bit_u64() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u64(0, n), None);
        }

        for i in 0..64 {
            assert_eq!(nth_set_bit_u64(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u64(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = 0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32;
            for i in 0..=64 {
                assert_eq!(nth_set_bit_u64(rnd, i), naive_nth_bit_set_u64(rnd, i));
            }
        }
    }
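
    // Illustrative extra check (a sketch, not part of the original suite):
    // exercise nth_set_bit_idx and its reverse variant through BitMask on a
    // small hand-computed mask.
    #[test]
    fn test_nth_set_bit_idx() {
        // Bits (LSB-first): 0, 1, 1, 0, 1 -> set bits at indices 1, 2, 4.
        let bytes = [0b10110u8];
        let mask = BitMask::new(&bytes, 0, 5);
        assert_eq!(mask.nth_set_bit_idx(0, 0), Some(1));
        assert_eq!(mask.nth_set_bit_idx(2, 0), Some(4));
        assert_eq!(mask.nth_set_bit_idx(3, 0), None);
        assert_eq!(mask.nth_set_bit_idx_rev(0, 5), Some(4));
        assert_eq!(mask.nth_set_bit_idx_rev(2, 5), Some(1));
        assert_eq!(mask.nth_set_bit_idx_rev(3, 5), None);
    }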
}