polars_arrow/bitmap/utils/iterator.rs

1use polars_utils::slice::load_padded_le_u64;
2
3use super::get_bit_unchecked;
4use crate::bitmap::MutableBitmap;
5use crate::trusted_len::TrustedLen;
6
/// An iterator over bits according to the [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit),
/// i.e. the bytes `[4u8, 128u8]` correspond to `[false, false, true, false, ..., true]`.
///
/// The iterator caches up to 64 not-yet-yielded bits in `word` and keeps the bytes that
/// have not been loaded yet in `bytes`.
#[derive(Clone, Debug)]
pub struct BitmapIter<'a> {
    /// Bytes whose bits have not been loaded into `word` yet.
    bytes: &'a [u8],
    /// Currently loaded bits; the next bit yielded from the front is the lowest one.
    word: u64,
    /// Number of valid (not yet yielded) bits in `word`.
    word_len: usize,
    /// Number of bits still in `bytes` that have not been loaded into `word`.
    rest_len: usize,
}
16
17impl<'a> BitmapIter<'a> {
18    /// Creates a new [`BitmapIter`].
19    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
20        if len == 0 {
21            return Self {
22                bytes,
23                word: 0,
24                word_len: 0,
25                rest_len: 0,
26            };
27        }
28
29        assert!(bytes.len() * 8 >= offset + len);
30        let first_byte_idx = offset / 8;
31        let bytes = &bytes[first_byte_idx..];
32        let offset = offset % 8;
33
34        // Make sure during our hot loop all our loads are full 8-byte loads
35        // by loading the remainder now if it exists.
36        let word = load_padded_le_u64(bytes) >> offset;
37        let mod8 = bytes.len() % 8;
38        let first_word_bytes = if mod8 > 0 { mod8 } else { 8 };
39        let bytes = &bytes[first_word_bytes..];
40
41        let word_len = (first_word_bytes * 8 - offset).min(len);
42        let rest_len = len - word_len;
43        Self {
44            bytes,
45            word,
46            word_len,
47            rest_len,
48        }
49    }
50
51    /// Consume and returns the numbers of `1` / `true` values at the beginning of the iterator.
52    ///
53    /// This performs the same operation as `(&mut iter).take_while(|b| b).count()`.
54    ///
55    /// This is a lot more efficient than consecutively polling the iterator and should therefore
56    /// be preferred, if the use-case allows for it.
57    pub fn take_leading_ones(&mut self) -> usize {
58        let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
59        self.word_len -= word_ones;
60        self.word = self.word.wrapping_shr(word_ones as u32);
61
62        if self.word_len != 0 {
63            return word_ones;
64        }
65
66        let mut num_leading_ones = word_ones;
67
68        while self.rest_len != 0 {
69            self.word_len = usize::min(self.rest_len, 64);
70            self.rest_len -= self.word_len;
71
72            unsafe {
73                let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
74                self.word = u64::from_le_bytes(chunk);
75                self.bytes = self.bytes.get_unchecked(8..);
76            }
77
78            let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
79            self.word_len -= word_ones;
80            self.word = self.word.wrapping_shr(word_ones as u32);
81            num_leading_ones += word_ones;
82
83            if self.word_len != 0 {
84                return num_leading_ones;
85            }
86        }
87
88        num_leading_ones
89    }
90
91    /// Consume and returns the numbers of `0` / `false` values that the start of the iterator.
92    ///
93    /// This performs the same operation as `(&mut iter).take_while(|b| !b).count()`.
94    ///
95    /// This is a lot more efficient than consecutively polling the iterator and should therefore
96    /// be preferred, if the use-case allows for it.
97    pub fn take_leading_zeros(&mut self) -> usize {
98        let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
99        self.word_len -= word_zeros;
100        self.word = self.word.wrapping_shr(word_zeros as u32);
101
102        if self.word_len != 0 {
103            return word_zeros;
104        }
105
106        let mut num_leading_zeros = word_zeros;
107
108        while self.rest_len != 0 {
109            self.word_len = usize::min(self.rest_len, 64);
110            self.rest_len -= self.word_len;
111            unsafe {
112                let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
113                self.word = u64::from_le_bytes(chunk);
114                self.bytes = self.bytes.get_unchecked(8..);
115            }
116
117            let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
118            self.word_len -= word_zeros;
119            self.word = self.word.wrapping_shr(word_zeros as u32);
120            num_leading_zeros += word_zeros;
121
122            if self.word_len != 0 {
123                return num_leading_zeros;
124            }
125        }
126
127        num_leading_zeros
128    }
129
130    /// Returns the number of remaining elements in the iterator
131    #[inline]
132    pub fn num_remaining(&self) -> usize {
133        self.word_len + self.rest_len
134    }
135
136    /// Collect at most `n` elements from this iterator into `bitmap`
137    pub fn collect_n_into(&mut self, bitmap: &mut MutableBitmap, n: usize) {
138        fn collect_word(
139            word: &mut u64,
140            word_len: &mut usize,
141            bitmap: &mut MutableBitmap,
142            n: &mut usize,
143        ) {
144            while *n > 0 && *word_len > 0 {
145                {
146                    let trailing_ones = u32::min(word.trailing_ones(), *word_len as u32);
147                    let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_ones);
148                    *word = word.wrapping_shr(shift);
149                    *word_len -= shift as usize;
150                    *n -= shift as usize;
151
152                    bitmap.extend_constant(shift as usize, true);
153                }
154
155                {
156                    let trailing_zeros = u32::min(word.trailing_zeros(), *word_len as u32);
157                    let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_zeros);
158                    *word = word.wrapping_shr(shift);
159                    *word_len -= shift as usize;
160                    *n -= shift as usize;
161
162                    bitmap.extend_constant(shift as usize, false);
163                }
164            }
165        }
166
167        let mut n = usize::min(n, self.num_remaining());
168        bitmap.reserve(n);
169
170        collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
171
172        if n == 0 {
173            return;
174        }
175
176        let num_words = n / 64;
177
178        if num_words > 0 {
179            assert!(self.bytes.len() >= num_words * size_of::<u64>());
180
181            bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize);
182
183            self.bytes = unsafe { self.bytes.get_unchecked(num_words * 8..) };
184            self.rest_len -= num_words * u64::BITS as usize;
185            n -= num_words * u64::BITS as usize;
186        }
187
188        if n == 0 {
189            return;
190        }
191
192        assert!(self.bytes.len() >= size_of::<u64>());
193
194        self.word_len = usize::min(self.rest_len, 64);
195        self.rest_len -= self.word_len;
196        unsafe {
197            let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
198            self.word = u64::from_le_bytes(chunk);
199            self.bytes = self.bytes.get_unchecked(8..);
200        }
201
202        collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
203
204        debug_assert!(self.num_remaining() == 0 || n == 0);
205    }
206}
207
208impl Iterator for BitmapIter<'_> {
209    type Item = bool;
210
211    #[inline]
212    fn next(&mut self) -> Option<Self::Item> {
213        if self.word_len == 0 {
214            if self.rest_len == 0 {
215                return None;
216            }
217
218            self.word_len = self.rest_len.min(64);
219            self.rest_len -= self.word_len;
220
221            unsafe {
222                let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
223                self.word = u64::from_le_bytes(chunk);
224                self.bytes = self.bytes.get_unchecked(8..);
225            }
226        }
227
228        let ret = self.word & 1 != 0;
229        self.word >>= 1;
230        self.word_len -= 1;
231        Some(ret)
232    }
233
234    #[inline]
235    fn size_hint(&self) -> (usize, Option<usize>) {
236        let num_remaining = self.num_remaining();
237        (num_remaining, Some(num_remaining))
238    }
239
240    #[inline]
241    fn nth(&mut self, mut n: usize) -> Option<Self::Item> {
242        if n >= self.word_len + self.rest_len {
243            self.word = 0;
244            self.word_len = 0;
245            self.rest_len = 0;
246            return None;
247        }
248
249        // Advance words in buffer, skip words as needed
250        if n >= self.word_len {
251            n -= self.word_len;
252
253            let word_offset = n / 64;
254            n -= word_offset * 64;
255            self.rest_len -= word_offset * 64;
256
257            self.word_len = self.rest_len.min(64);
258            self.rest_len -= self.word_len;
259
260            let byte_offset = 8 * word_offset;
261
262            // Safety: bytes is large enough at construction time.
263            debug_assert!(byte_offset + 8 <= self.bytes.len());
264            unsafe {
265                let chunk = self
266                    .bytes
267                    .get_unchecked(byte_offset..byte_offset + 8)
268                    .try_into()
269                    .unwrap();
270                self.word = u64::from_le_bytes(chunk);
271                self.bytes = self.bytes.get_unchecked(byte_offset + 8..);
272            }
273        }
274
275        // At this point, n < self.word_len
276        debug_assert!(self.word_len > n);
277
278        // Advance index by n and take value at final index
279        self.word >>= n;
280        self.word_len -= n;
281
282        let ret = self.word & 1 != 0;
283        self.word >>= 1;
284        self.word_len -= 1;
285        Some(ret)
286    }
287}
288
289impl DoubleEndedIterator for BitmapIter<'_> {
290    #[inline]
291    fn next_back(&mut self) -> Option<bool> {
292        if self.rest_len > 0 {
293            self.rest_len -= 1;
294            Some(unsafe { get_bit_unchecked(self.bytes, self.rest_len) })
295        } else if self.word_len > 0 {
296            self.word_len -= 1;
297            Some(self.word & (1 << self.word_len) != 0)
298        } else {
299            None
300        }
301    }
302}
303
304unsafe impl TrustedLen for BitmapIter<'_> {}
305impl ExactSizeIterator for BitmapIter<'_> {}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_collect_into_17579() {
        // Requesting more bits than remain must clamp to the 128 available bits.
        let mut bitmap = MutableBitmap::with_capacity(64);
        BitmapIter::new(&[0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0], 0, 128)
            .collect_n_into(&mut bitmap, 129);

        let bitmap = bitmap.freeze();

        assert_eq!(bitmap.len(), 128);
        assert_eq!(bitmap.set_bits(), 4);
    }

    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_collect_into() {
        for _ in 0..10_000 {
            let mut set_bits = 0;
            let mut unset_bits = 0;

            let mut length = 0;
            let mut pattern = Vec::new();
            for _ in 0..rand::random::<u64>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                set_bits += word.count_ones();
                unset_bits += word.count_zeros();
                length += 64;
            }

            for _ in 0..rand::random::<u64>() % 7 {
                let b = rand::random::<u8>();
                pattern.push(b);
                set_bits += b.count_ones();
                unset_bits += b.count_zeros();
                length += 8;
            }

            let last_length = rand::random::<u64>() % 8;
            if last_length != 0 {
                let b = rand::random::<u8>();
                pattern.push(b);
                let ones = (b & ((1 << last_length) - 1)).count_ones();
                set_bits += ones;
                unset_bits += last_length as u32 - ones;
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length as usize);
            let mut bitmap = MutableBitmap::with_capacity(length as usize);

            while iter.num_remaining() > 0 {
                let len_before = bitmap.len();
                // Draw `n` from `0..=num_remaining`. With `% num_remaining` the final bit
                // could never be requested, so the loop would never terminate.
                let n = rand::random::<u64>() as usize % (iter.num_remaining() + 1);
                iter.collect_n_into(&mut bitmap, n);

                // Ensure we are booking the progress we expect
                assert_eq!(bitmap.len(), len_before + n);
            }

            let bitmap = bitmap.freeze();

            assert_eq!(bitmap.set_bits(), set_bits as usize);
            assert_eq!(bitmap.unset_bits(), unset_bits as usize);
        }
    }

    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_leading_ops() {
        for _ in 0..10_000 {
            let mut length = 0;
            let mut pattern = Vec::new();
            for _ in 0..rand::random::<u64>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                length += 64;
            }

            for _ in 0..rand::random::<u64>() % 7 {
                pattern.push(rand::random::<u8>());
                length += 8;
            }

            let last_length = rand::random::<u64>() % 8;
            if last_length != 0 {
                pattern.push(rand::random::<u8>());
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length as usize);

            let mut prev_remaining = iter.num_remaining();
            while iter.num_remaining() != 0 {
                let num_ones = iter.clone().take_leading_ones();
                assert_eq!(num_ones, (&mut iter).take_while(|&b| b).count());

                let num_zeros = iter.clone().take_leading_zeros();
                assert_eq!(num_zeros, (&mut iter).take_while(|&b| !b).count());

                // Ensure that we are making progress
                assert!(iter.num_remaining() < prev_remaining);
                prev_remaining = iter.num_remaining();
            }

            assert_eq!(iter.take_leading_zeros(), 0);
            assert_eq!(iter.take_leading_ones(), 0);
        }
    }

    #[test]
    #[allow(clippy::iter_nth_zero)]
    fn test_bitmap_iter_nth() {
        // Calling nth repeatedly advances through the bitmap
        {
            let mut iter = BitmapIter::new(&[0b10110001], 0, 8);
            assert_eq!(iter.nth(0), Some(true));
            assert_eq!(iter.nth(0), Some(false));
            assert_eq!(iter.nth(2), Some(true));
            assert_eq!(iter.nth(3), None);

            assert_eq!(iter.next(), None);
        }

        // Test parity with a next()-based implementation for a single call to nth()
        for len in [0, 1, 2, 63, 64, 65, 127, 128, 129] {
            for offset in [0, 1, 2] {
                // binary '01010101' == 85
                let iter = BitmapIter::new(
                    &[
                        0, 1, 2, 4, 8, 16, 32, 64, 85, 170, 85, 170, 85, 170, 85, 170, 255, 0,
                    ],
                    offset,
                    len,
                );

                for i in 0..=len {
                    let mut iter_expected = iter.clone();
                    let mut iter_test = iter.clone();

                    let prev_rest_len = iter_test.rest_len;
                    let prev_word_len = iter_test.word_len;

                    assert_eq!(len, prev_rest_len + prev_word_len);

                    // Iterate.
                    let out = iter_test.nth(i);
                    for _ in 0..i {
                        iter_expected.next();
                    }
                    let expected = iter_expected.next();

                    // Check value.
                    assert_eq!(out, expected);

                    // Check internal state.
                    let final_rest_len = iter_test.rest_len;
                    let final_word_len = iter_test.word_len;
                    match out {
                        Some(_) => assert_eq!(
                            prev_rest_len + prev_word_len,
                            i + 1 + final_rest_len + final_word_len
                        ),
                        None => {
                            assert!(i >= prev_rest_len + prev_word_len);
                            assert_eq!(final_rest_len + final_word_len, 0)
                        },
                    };
                }
            }
        }

        // Check internal state on repeat calls to nth(). A fresh iterator is built for
        // every step size; otherwise later step sizes only see an exhausted iterator.
        {
            for len in [0, 63, 64, 65, 126, 128, 129] {
                for step in [0, 1, 2, 3] {
                    let mut iter = BitmapIter::new(
                        &[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0],
                        0,
                        len,
                    );
                    for i in (0..len + step + 1).step_by(step + 1) {
                        let prev_rest_len = iter.rest_len;
                        let prev_word_len = iter.word_len;

                        // Every previous call consumed exactly `step + 1` bits (or drained
                        // the iterator), so `i` bits are gone.
                        assert_eq!(prev_rest_len + prev_word_len, len.saturating_sub(i));

                        let out = iter.nth(step);

                        let final_rest_len = iter.rest_len;
                        let final_word_len = iter.word_len;
                        match out {
                            Some(_) => assert_eq!(
                                prev_rest_len + prev_word_len,
                                step + 1 + final_rest_len + final_word_len
                            ),
                            None => {
                                assert!(step >= prev_rest_len + prev_word_len);
                                assert_eq!(final_rest_len + final_word_len, 0)
                            },
                        };
                    }
                }
            }
        }

        // Edge cases
        let mut iter = BitmapIter::new(&[], 0, 0);
        assert_eq!(iter.nth(0), None);
    }
}