polars_arrow/bitmap/utils/
iterator.rs

1use polars_utils::slice::load_padded_le_u64;
2
3use super::get_bit_unchecked;
4use crate::bitmap::MutableBitmap;
5use crate::trusted_len::TrustedLen;
6
/// An iterator over bits according to the [LSB](https://en.wikipedia.org/wiki/Bit_numbering#Least_significant_bit),
/// i.e. the bytes `[4u8, 128u8]` correspond to `[false, false, true, false, ..., true]`.
///
/// Bits are served from a 64-bit `word` that is refilled 8 bytes at a time from `bytes`.
#[derive(Debug, Clone)]
pub struct BitmapIter<'a> {
    /// Bytes that have not been loaded into `word` yet.
    bytes: &'a [u8],
    /// Word currently being consumed; the next bit to yield from the front is its lowest bit.
    word: u64,
    /// Number of bits of `word` that still belong to the iterator.
    word_len: usize,
    /// Number of bits that still have to be read from `bytes`.
    rest_len: usize,
}
16
17impl<'a> BitmapIter<'a> {
18    /// Creates a new [`BitmapIter`].
19    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
20        if len == 0 {
21            return Self {
22                bytes,
23                word: 0,
24                word_len: 0,
25                rest_len: 0,
26            };
27        }
28
29        assert!(bytes.len() * 8 >= offset + len);
30        let first_byte_idx = offset / 8;
31        let bytes = &bytes[first_byte_idx..];
32        let offset = offset % 8;
33
34        // Make sure during our hot loop all our loads are full 8-byte loads
35        // by loading the remainder now if it exists.
36        let word = load_padded_le_u64(bytes) >> offset;
37        let mod8 = bytes.len() % 8;
38        let first_word_bytes = if mod8 > 0 { mod8 } else { 8 };
39        let bytes = &bytes[first_word_bytes..];
40
41        let word_len = (first_word_bytes * 8 - offset).min(len);
42        let rest_len = len - word_len;
43        Self {
44            bytes,
45            word,
46            word_len,
47            rest_len,
48        }
49    }
50
51    /// Consume and returns the numbers of `1` / `true` values at the beginning of the iterator.
52    ///
53    /// This performs the same operation as `(&mut iter).take_while(|b| b).count()`.
54    ///
55    /// This is a lot more efficient than consecutively polling the iterator and should therefore
56    /// be preferred, if the use-case allows for it.
57    pub fn take_leading_ones(&mut self) -> usize {
58        let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
59        self.word_len -= word_ones;
60        self.word = self.word.wrapping_shr(word_ones as u32);
61
62        if self.word_len != 0 {
63            return word_ones;
64        }
65
66        let mut num_leading_ones = word_ones;
67
68        while self.rest_len != 0 {
69            self.word_len = usize::min(self.rest_len, 64);
70            self.rest_len -= self.word_len;
71
72            unsafe {
73                let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
74                self.word = u64::from_le_bytes(chunk);
75                self.bytes = self.bytes.get_unchecked(8..);
76            }
77
78            let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
79            self.word_len -= word_ones;
80            self.word = self.word.wrapping_shr(word_ones as u32);
81            num_leading_ones += word_ones;
82
83            if self.word_len != 0 {
84                return num_leading_ones;
85            }
86        }
87
88        num_leading_ones
89    }
90
91    /// Consume and returns the numbers of `0` / `false` values that the start of the iterator.
92    ///
93    /// This performs the same operation as `(&mut iter).take_while(|b| !b).count()`.
94    ///
95    /// This is a lot more efficient than consecutively polling the iterator and should therefore
96    /// be preferred, if the use-case allows for it.
97    pub fn take_leading_zeros(&mut self) -> usize {
98        let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
99        self.word_len -= word_zeros;
100        self.word = self.word.wrapping_shr(word_zeros as u32);
101
102        if self.word_len != 0 {
103            return word_zeros;
104        }
105
106        let mut num_leading_zeros = word_zeros;
107
108        while self.rest_len != 0 {
109            self.word_len = usize::min(self.rest_len, 64);
110            self.rest_len -= self.word_len;
111            unsafe {
112                let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
113                self.word = u64::from_le_bytes(chunk);
114                self.bytes = self.bytes.get_unchecked(8..);
115            }
116
117            let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
118            self.word_len -= word_zeros;
119            self.word = self.word.wrapping_shr(word_zeros as u32);
120            num_leading_zeros += word_zeros;
121
122            if self.word_len != 0 {
123                return num_leading_zeros;
124            }
125        }
126
127        num_leading_zeros
128    }
129
130    /// Returns the number of remaining elements in the iterator
131    #[inline]
132    pub fn num_remaining(&self) -> usize {
133        self.word_len + self.rest_len
134    }
135
136    /// Collect at most `n` elements from this iterator into `bitmap`
137    pub fn collect_n_into(&mut self, bitmap: &mut MutableBitmap, n: usize) {
138        fn collect_word(
139            word: &mut u64,
140            word_len: &mut usize,
141            bitmap: &mut MutableBitmap,
142            n: &mut usize,
143        ) {
144            while *n > 0 && *word_len > 0 {
145                {
146                    let trailing_ones = u32::min(word.trailing_ones(), *word_len as u32);
147                    let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_ones);
148                    *word = word.wrapping_shr(shift);
149                    *word_len -= shift as usize;
150                    *n -= shift as usize;
151
152                    bitmap.extend_constant(shift as usize, true);
153                }
154
155                {
156                    let trailing_zeros = u32::min(word.trailing_zeros(), *word_len as u32);
157                    let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_zeros);
158                    *word = word.wrapping_shr(shift);
159                    *word_len -= shift as usize;
160                    *n -= shift as usize;
161
162                    bitmap.extend_constant(shift as usize, false);
163                }
164            }
165        }
166
167        let mut n = usize::min(n, self.num_remaining());
168        bitmap.reserve(n);
169
170        collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
171
172        if n == 0 {
173            return;
174        }
175
176        let num_words = n / 64;
177
178        if num_words > 0 {
179            assert!(self.bytes.len() >= num_words * size_of::<u64>());
180
181            bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize);
182
183            self.bytes = unsafe { self.bytes.get_unchecked(num_words * 8..) };
184            self.rest_len -= num_words * u64::BITS as usize;
185            n -= num_words * u64::BITS as usize;
186        }
187
188        if n == 0 {
189            return;
190        }
191
192        assert!(self.bytes.len() >= size_of::<u64>());
193
194        self.word_len = usize::min(self.rest_len, 64);
195        self.rest_len -= self.word_len;
196        unsafe {
197            let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
198            self.word = u64::from_le_bytes(chunk);
199            self.bytes = self.bytes.get_unchecked(8..);
200        }
201
202        collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
203
204        debug_assert!(self.num_remaining() == 0 || n == 0);
205    }
206}
207
208impl Iterator for BitmapIter<'_> {
209    type Item = bool;
210
211    #[inline]
212    fn next(&mut self) -> Option<Self::Item> {
213        if self.word_len == 0 {
214            if self.rest_len == 0 {
215                return None;
216            }
217
218            self.word_len = self.rest_len.min(64);
219            self.rest_len -= self.word_len;
220
221            unsafe {
222                let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
223                self.word = u64::from_le_bytes(chunk);
224                self.bytes = self.bytes.get_unchecked(8..);
225            }
226        }
227
228        let ret = self.word & 1 != 0;
229        self.word >>= 1;
230        self.word_len -= 1;
231        Some(ret)
232    }
233
234    #[inline]
235    fn size_hint(&self) -> (usize, Option<usize>) {
236        let num_remaining = self.num_remaining();
237        (num_remaining, Some(num_remaining))
238    }
239}
240
241impl DoubleEndedIterator for BitmapIter<'_> {
242    #[inline]
243    fn next_back(&mut self) -> Option<bool> {
244        if self.rest_len > 0 {
245            self.rest_len -= 1;
246            Some(unsafe { get_bit_unchecked(self.bytes, self.rest_len) })
247        } else if self.word_len > 0 {
248            self.word_len -= 1;
249            Some(self.word & (1 << self.word_len) != 0)
250        } else {
251            None
252        }
253    }
254}
255
// SAFETY: `size_hint` returns `num_remaining()` as both its lower and upper bound, and
// `next` / `next_back` each remove exactly one bit from that count per yielded item.
// NOTE(review): presumably this exactness is what `TrustedLen` requires — confirm against
// the trait's documentation, which is not visible in this file.
unsafe impl TrustedLen for BitmapIter<'_> {}
// The default `len()` is derived from `size_hint`, which is exact for this iterator.
impl ExactSizeIterator for BitmapIter<'_> {}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Asks for more bits (129) than the iterator holds (128) and checks that exactly the
    /// four set bits of the input end up in the collected bitmap.
    #[test]
    fn test_collect_into_17579() {
        let mut bitmap = MutableBitmap::with_capacity(64);
        BitmapIter::new(&[0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0], 0, 128)
            .collect_n_into(&mut bitmap, 129);

        let bitmap = bitmap.freeze();

        assert_eq!(bitmap.set_bits(), 4);
    }

    /// Collects random bit patterns in randomly sized steps and checks that the number of
    /// set and unset bits survives the round trip.
    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_collect_into() {
        for _ in 0..10_000 {
            let mut set_bits = 0;
            let mut unset_bits = 0;

            let mut length = 0;
            let mut pattern = Vec::new();

            // A random number of full words, biased towards all-zero and all-one words.
            for _ in 0..rand::random::<usize>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                set_bits += word.count_ones();
                unset_bits += word.count_zeros();
                length += 64;
            }

            // Up to six extra full bytes, so the byte length is not a multiple of 8.
            for _ in 0..rand::random::<usize>() % 7 {
                let b = rand::random::<u8>();
                pattern.push(b);
                set_bits += b.count_ones();
                unset_bits += b.count_zeros();
                length += 8;
            }

            // Optionally a final, partially used byte.
            let last_length = rand::random::<usize>() % 8;
            if last_length != 0 {
                let b = rand::random::<u8>();
                pattern.push(b);
                let ones = (b & ((1 << last_length) - 1)).count_ones();
                set_bits += ones;
                unset_bits += last_length as u32 - ones;
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length);
            let mut bitmap = MutableBitmap::with_capacity(length);

            while iter.num_remaining() > 0 {
                let len_before = bitmap.len();
                // Draw from `0..=num_remaining`. Drawing from `0..num_remaining` can only
                // ever yield 0 once a single bit is left, and the loop would never finish.
                let n = rand::random::<usize>() % (iter.num_remaining() + 1);
                iter.collect_n_into(&mut bitmap, n);

                // Ensure we are booking the progress we expect
                assert_eq!(bitmap.len(), len_before + n);
            }

            let bitmap = bitmap.freeze();

            assert_eq!(bitmap.set_bits(), set_bits as usize);
            assert_eq!(bitmap.unset_bits(), unset_bits as usize);
        }
    }

    /// Checks `take_leading_ones` / `take_leading_zeros` against counting the same run
    /// with `take_while` on random bit patterns.
    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_leading_ops() {
        for _ in 0..10_000 {
            let mut length = 0;
            let mut pattern = Vec::new();

            // A random number of full words, biased towards all-zero and all-one words.
            for _ in 0..rand::random::<usize>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                length += 64;
            }

            // Up to six extra full bytes, so the byte length is not a multiple of 8.
            for _ in 0..rand::random::<usize>() % 7 {
                pattern.push(rand::random::<u8>());
                length += 8;
            }

            // Optionally a final, partially used byte.
            let last_length = rand::random::<usize>() % 8;
            if last_length != 0 {
                pattern.push(rand::random::<u8>());
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length);

            let mut prev_remaining = iter.num_remaining();
            while iter.num_remaining() != 0 {
                // The bulk operation runs on a clone; `take_while` then advances the real
                // iterator (including the bit that ends the run).
                let num_ones = iter.clone().take_leading_ones();
                assert_eq!(num_ones, (&mut iter).take_while(|&b| b).count());

                let num_zeros = iter.clone().take_leading_zeros();
                assert_eq!(num_zeros, (&mut iter).take_while(|&b| !b).count());

                // Ensure that we are making progress
                assert!(iter.num_remaining() < prev_remaining);
                prev_remaining = iter.num_remaining();
            }

            assert_eq!(iter.take_leading_zeros(), 0);
            assert_eq!(iter.take_leading_ones(), 0);
        }
    }
}