1use polars_utils::slice::load_padded_le_u64;
2
3use super::get_bit_unchecked;
4use crate::bitmap::MutableBitmap;
5use crate::trusted_len::TrustedLen;
6
/// An iterator over the bits of a byte buffer, yielding each bit as a `bool`.
///
/// Bits are served from a 64-bit buffer (`word`, with the next bit in the least-significant
/// position); when that buffer runs dry, the next little-endian `u64` is read from `bytes`.
#[derive(Clone, Debug)]
pub struct BitmapIter<'a> {
    /// Bytes not yet loaded into `word`; the next refill starts at `bytes[0]`.
    bytes: &'a [u8],
    /// Buffered bits; bit 0 is the next one yielded from the front.
    word: u64,
    /// Number of low bits of `word` that are still valid (at most 64).
    word_len: usize,
    /// Number of bits remaining in `bytes` beyond those buffered in `word`.
    rest_len: usize,
}
16
17impl<'a> BitmapIter<'a> {
18 pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
20 if len == 0 {
21 return Self {
22 bytes,
23 word: 0,
24 word_len: 0,
25 rest_len: 0,
26 };
27 }
28
29 assert!(bytes.len() * 8 >= offset + len);
30 let first_byte_idx = offset / 8;
31 let bytes = &bytes[first_byte_idx..];
32 let offset = offset % 8;
33
34 let word = load_padded_le_u64(bytes) >> offset;
37 let mod8 = bytes.len() % 8;
38 let first_word_bytes = if mod8 > 0 { mod8 } else { 8 };
39 let bytes = &bytes[first_word_bytes..];
40
41 let word_len = (first_word_bytes * 8 - offset).min(len);
42 let rest_len = len - word_len;
43 Self {
44 bytes,
45 word,
46 word_len,
47 rest_len,
48 }
49 }
50
51 pub fn take_leading_ones(&mut self) -> usize {
58 let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
59 self.word_len -= word_ones;
60 self.word = self.word.wrapping_shr(word_ones as u32);
61
62 if self.word_len != 0 {
63 return word_ones;
64 }
65
66 let mut num_leading_ones = word_ones;
67
68 while self.rest_len != 0 {
69 self.word_len = usize::min(self.rest_len, 64);
70 self.rest_len -= self.word_len;
71
72 unsafe {
73 let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
74 self.word = u64::from_le_bytes(chunk);
75 self.bytes = self.bytes.get_unchecked(8..);
76 }
77
78 let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
79 self.word_len -= word_ones;
80 self.word = self.word.wrapping_shr(word_ones as u32);
81 num_leading_ones += word_ones;
82
83 if self.word_len != 0 {
84 return num_leading_ones;
85 }
86 }
87
88 num_leading_ones
89 }
90
91 pub fn take_leading_zeros(&mut self) -> usize {
98 let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
99 self.word_len -= word_zeros;
100 self.word = self.word.wrapping_shr(word_zeros as u32);
101
102 if self.word_len != 0 {
103 return word_zeros;
104 }
105
106 let mut num_leading_zeros = word_zeros;
107
108 while self.rest_len != 0 {
109 self.word_len = usize::min(self.rest_len, 64);
110 self.rest_len -= self.word_len;
111 unsafe {
112 let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
113 self.word = u64::from_le_bytes(chunk);
114 self.bytes = self.bytes.get_unchecked(8..);
115 }
116
117 let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
118 self.word_len -= word_zeros;
119 self.word = self.word.wrapping_shr(word_zeros as u32);
120 num_leading_zeros += word_zeros;
121
122 if self.word_len != 0 {
123 return num_leading_zeros;
124 }
125 }
126
127 num_leading_zeros
128 }
129
130 #[inline]
132 pub fn num_remaining(&self) -> usize {
133 self.word_len + self.rest_len
134 }
135
136 pub fn collect_n_into(&mut self, bitmap: &mut MutableBitmap, n: usize) {
138 fn collect_word(
139 word: &mut u64,
140 word_len: &mut usize,
141 bitmap: &mut MutableBitmap,
142 n: &mut usize,
143 ) {
144 while *n > 0 && *word_len > 0 {
145 {
146 let trailing_ones = u32::min(word.trailing_ones(), *word_len as u32);
147 let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_ones);
148 *word = word.wrapping_shr(shift);
149 *word_len -= shift as usize;
150 *n -= shift as usize;
151
152 bitmap.extend_constant(shift as usize, true);
153 }
154
155 {
156 let trailing_zeros = u32::min(word.trailing_zeros(), *word_len as u32);
157 let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_zeros);
158 *word = word.wrapping_shr(shift);
159 *word_len -= shift as usize;
160 *n -= shift as usize;
161
162 bitmap.extend_constant(shift as usize, false);
163 }
164 }
165 }
166
167 let mut n = usize::min(n, self.num_remaining());
168 bitmap.reserve(n);
169
170 collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
171
172 if n == 0 {
173 return;
174 }
175
176 let num_words = n / 64;
177
178 if num_words > 0 {
179 assert!(self.bytes.len() >= num_words * size_of::<u64>());
180
181 bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize);
182
183 self.bytes = unsafe { self.bytes.get_unchecked(num_words * 8..) };
184 self.rest_len -= num_words * u64::BITS as usize;
185 n -= num_words * u64::BITS as usize;
186 }
187
188 if n == 0 {
189 return;
190 }
191
192 assert!(self.bytes.len() >= size_of::<u64>());
193
194 self.word_len = usize::min(self.rest_len, 64);
195 self.rest_len -= self.word_len;
196 unsafe {
197 let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
198 self.word = u64::from_le_bytes(chunk);
199 self.bytes = self.bytes.get_unchecked(8..);
200 }
201
202 collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
203
204 debug_assert!(self.num_remaining() == 0 || n == 0);
205 }
206}
207
208impl Iterator for BitmapIter<'_> {
209 type Item = bool;
210
211 #[inline]
212 fn next(&mut self) -> Option<Self::Item> {
213 if self.word_len == 0 {
214 if self.rest_len == 0 {
215 return None;
216 }
217
218 self.word_len = self.rest_len.min(64);
219 self.rest_len -= self.word_len;
220
221 unsafe {
222 let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
223 self.word = u64::from_le_bytes(chunk);
224 self.bytes = self.bytes.get_unchecked(8..);
225 }
226 }
227
228 let ret = self.word & 1 != 0;
229 self.word >>= 1;
230 self.word_len -= 1;
231 Some(ret)
232 }
233
234 #[inline]
235 fn size_hint(&self) -> (usize, Option<usize>) {
236 let num_remaining = self.num_remaining();
237 (num_remaining, Some(num_remaining))
238 }
239
240 #[inline]
241 fn nth(&mut self, mut n: usize) -> Option<Self::Item> {
242 if n >= self.word_len + self.rest_len {
243 self.word = 0;
244 self.word_len = 0;
245 self.rest_len = 0;
246 return None;
247 }
248
249 if n >= self.word_len {
251 n -= self.word_len;
252
253 let word_offset = n / 64;
254 n -= word_offset * 64;
255 self.rest_len -= word_offset * 64;
256
257 self.word_len = self.rest_len.min(64);
258 self.rest_len -= self.word_len;
259
260 let byte_offset = 8 * word_offset;
261
262 debug_assert!(byte_offset + 8 <= self.bytes.len());
264 unsafe {
265 let chunk = self
266 .bytes
267 .get_unchecked(byte_offset..byte_offset + 8)
268 .try_into()
269 .unwrap();
270 self.word = u64::from_le_bytes(chunk);
271 self.bytes = self.bytes.get_unchecked(byte_offset + 8..);
272 }
273 }
274
275 debug_assert!(self.word_len > n);
277
278 self.word >>= n;
280 self.word_len -= n;
281
282 let ret = self.word & 1 != 0;
283 self.word >>= 1;
284 self.word_len -= 1;
285 Some(ret)
286 }
287}
288
289impl DoubleEndedIterator for BitmapIter<'_> {
290 #[inline]
291 fn next_back(&mut self) -> Option<bool> {
292 if self.rest_len > 0 {
293 self.rest_len -= 1;
294 Some(unsafe { get_bit_unchecked(self.bytes, self.rest_len) })
295 } else if self.word_len > 0 {
296 self.word_len -= 1;
297 Some(self.word & (1 << self.word_len) != 0)
298 } else {
299 None
300 }
301 }
302}
303
304unsafe impl TrustedLen for BitmapIter<'_> {}
305impl ExactSizeIterator for BitmapIter<'_> {}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Asking `collect_n_into` for more bits than remain (129 of 128) must clamp to the
    /// available bits and still copy every set bit.
    #[test]
    fn test_collect_into_17579() {
        let mut bitmap = MutableBitmap::with_capacity(64);
        BitmapIter::new(&[0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0], 0, 128)
            .collect_n_into(&mut bitmap, 129);

        let bitmap = bitmap.freeze();

        assert_eq!(bitmap.set_bits(), 4);
    }

    /// Drains random bitmaps through `collect_n_into` in random-sized pieces and checks that the
    /// collected length and the set/unset bit counts match the source pattern.
    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_collect_into() {
        for _ in 0..10_000 {
            let mut set_bits = 0;
            let mut unset_bits = 0;

            let mut length = 0;
            let mut pattern = Vec::new();
            for _ in 0..rand::random::<u64>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                set_bits += word.count_ones();
                unset_bits += word.count_zeros();
                length += 64;
            }

            for _ in 0..rand::random::<u64>() % 7 {
                let b = rand::random::<u8>();
                pattern.push(b);
                set_bits += b.count_ones();
                unset_bits += b.count_zeros();
                length += 8;
            }

            // A final, partially used byte: only its low `last_length` bits count.
            let last_length = rand::random::<u64>() % 8;
            if last_length != 0 {
                let b = rand::random::<u8>();
                pattern.push(b);
                let ones = (b & ((1 << last_length) - 1)).count_ones();
                set_bits += ones;
                unset_bits += last_length as u32 - ones;
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length as usize);
            let mut bitmap = MutableBitmap::with_capacity(length as usize);

            while iter.num_remaining() > 0 {
                let len_before = bitmap.len();
                // Draw from `0..=num_remaining`. Drawing from `0..num_remaining` could never
                // consume the final bit, so this loop would never terminate.
                let n = rand::random::<u64>() as usize % (iter.num_remaining() + 1);
                iter.collect_n_into(&mut bitmap, n);

                // Exactly `n` bits must have been appended.
                assert_eq!(bitmap.len(), len_before + n);
            }

            let bitmap = bitmap.freeze();

            assert_eq!(bitmap.set_bits(), set_bits as usize);
            assert_eq!(bitmap.unset_bits(), unset_bits as usize);
        }
    }

    /// Cross-checks `take_leading_ones`/`take_leading_zeros` against the equivalent
    /// `take_while(..).count()` on random bitmaps.
    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_leading_ops() {
        for _ in 0..10_000 {
            let mut length = 0;
            let mut pattern = Vec::new();
            for _ in 0..rand::random::<u64>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                length += 64;
            }

            for _ in 0..rand::random::<u64>() % 7 {
                pattern.push(rand::random::<u8>());
                length += 8;
            }

            let last_length = rand::random::<u64>() % 8;
            if last_length != 0 {
                pattern.push(rand::random::<u8>());
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length as usize);

            let mut prev_remaining = iter.num_remaining();
            while iter.num_remaining() != 0 {
                // `take_while` also consumes the first non-matching bit. Both sides start from
                // the same state, though, so the counts remain comparable.
                let num_ones = iter.clone().take_leading_ones();
                assert_eq!(num_ones, (&mut iter).take_while(|&b| b).count());

                let num_zeros = iter.clone().take_leading_zeros();
                assert_eq!(num_zeros, (&mut iter).take_while(|&b| !b).count());

                // Every round must make progress.
                assert!(iter.num_remaining() < prev_remaining);
                prev_remaining = iter.num_remaining();
            }

            assert_eq!(iter.take_leading_zeros(), 0);
            assert_eq!(iter.take_leading_ones(), 0);
        }
    }

    #[test]
    #[allow(clippy::iter_nth_zero)]
    fn test_bitmap_iter_nth() {
        {
            // 0b10110001 read LSB-first is 1,0,0,0,1,1,0,1.
            let mut iter = BitmapIter::new(&[0b10110001], 0, 8);
            assert_eq!(iter.nth(0), Some(true));
            assert_eq!(iter.nth(0), Some(false));
            assert_eq!(iter.nth(2), Some(true));
            assert_eq!(iter.nth(3), None);

            assert_eq!(iter.next(), None);
        }

        // `nth(i)` must agree with `i` calls to `next()` followed by one more, for lengths and
        // offsets that straddle word boundaries.
        for len in [0, 1, 2, 63, 64, 65, 127, 128, 129] {
            for offset in [0, 1, 2] {
                let iter = BitmapIter::new(
                    &[
                        0, 1, 2, 4, 8, 16, 32, 64, 85, 170, 85, 170, 85, 170, 85, 170, 255, 0,
                    ],
                    offset,
                    len,
                );

                for i in 0..=len {
                    let mut iter_expected = iter.clone();
                    let mut iter_test = iter.clone();

                    let prev_rest_len = iter_test.rest_len;
                    let prev_word_len = iter_test.word_len;

                    assert_eq!(len, prev_rest_len + prev_word_len);

                    let out = iter_test.nth(i);

                    for _ in 0..i {
                        iter_expected.next();
                    }
                    let expected = iter_expected.next();

                    assert_eq!(out, expected);

                    // The bit accounting must drop by exactly `i + 1`, or to zero when exhausted.
                    let final_rest_len = iter_test.rest_len;
                    let final_word_len = iter_test.word_len;
                    match out {
                        Some(_) => assert_eq!(
                            prev_rest_len + prev_word_len,
                            i + 1 + final_rest_len + final_word_len
                        ),
                        None => {
                            assert!(i >= prev_rest_len + prev_word_len);
                            assert_eq!(final_rest_len + final_word_len, 0)
                        },
                    };
                }
            }
        }

        {
            // Repeated `nth(step)` calls on one iterator keep the bit accounting consistent.
            for len in [0, 63, 64, 65, 126, 128, 129] {
                let mut iter =
                    BitmapIter::new(&[0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0], 0, len);
                for step in [0, 1, 2, 3] {
                    for i in (0..len + step + 1).step_by(step + 1) {
                        let prev_rest_len = iter.rest_len;
                        let prev_word_len = iter.word_len;

                        let out = iter.nth(step);

                        let final_rest_len = iter.rest_len;
                        let final_word_len = iter.word_len;
                        match out {
                            Some(_) => assert_eq!(
                                prev_rest_len + prev_word_len,
                                step + 1 + final_rest_len + final_word_len
                            ),
                            None => {
                                assert!(i >= prev_rest_len + prev_word_len);
                                assert_eq!(final_rest_len + final_word_len, 0)
                            },
                        };
                    }
                }
            }
        }

        // An empty iterator yields nothing.
        let mut iter = BitmapIter::new(&[], 0, 0);
        assert_eq!(iter.nth(0), None);
    }
}