polars_arrow/bitmap/utils/
iterator.rs1use polars_utils::slice::load_padded_le_u64;
2
3use super::get_bit_unchecked;
4use crate::bitmap::MutableBitmap;
5use crate::trusted_len::TrustedLen;
6
/// Iterator over the bits of a byte slice, yielding each bit as a `bool`.
///
/// Bits are consumed least-significant-bit first from a 64-bit working word
/// that is refilled eight bytes at a time from `bytes`.
#[derive(Debug, Clone)]
pub struct BitmapIter<'a> {
    /// Bytes that have not been loaded into `word` yet.
    bytes: &'a [u8],
    /// Working word; its lowest `word_len` bits are the next bits to yield.
    word: u64,
    /// Number of bits still to be yielded from `word`.
    word_len: usize,
    /// Number of bits still to be yielded from `bytes`.
    rest_len: usize,
}
16
17impl<'a> BitmapIter<'a> {
18 pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
20 if len == 0 {
21 return Self {
22 bytes,
23 word: 0,
24 word_len: 0,
25 rest_len: 0,
26 };
27 }
28
29 assert!(bytes.len() * 8 >= offset + len);
30 let first_byte_idx = offset / 8;
31 let bytes = &bytes[first_byte_idx..];
32 let offset = offset % 8;
33
34 let word = load_padded_le_u64(bytes) >> offset;
37 let mod8 = bytes.len() % 8;
38 let first_word_bytes = if mod8 > 0 { mod8 } else { 8 };
39 let bytes = &bytes[first_word_bytes..];
40
41 let word_len = (first_word_bytes * 8 - offset).min(len);
42 let rest_len = len - word_len;
43 Self {
44 bytes,
45 word,
46 word_len,
47 rest_len,
48 }
49 }
50
51 pub fn take_leading_ones(&mut self) -> usize {
58 let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
59 self.word_len -= word_ones;
60 self.word = self.word.wrapping_shr(word_ones as u32);
61
62 if self.word_len != 0 {
63 return word_ones;
64 }
65
66 let mut num_leading_ones = word_ones;
67
68 while self.rest_len != 0 {
69 self.word_len = usize::min(self.rest_len, 64);
70 self.rest_len -= self.word_len;
71
72 unsafe {
73 let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
74 self.word = u64::from_le_bytes(chunk);
75 self.bytes = self.bytes.get_unchecked(8..);
76 }
77
78 let word_ones = usize::min(self.word_len, self.word.trailing_ones() as usize);
79 self.word_len -= word_ones;
80 self.word = self.word.wrapping_shr(word_ones as u32);
81 num_leading_ones += word_ones;
82
83 if self.word_len != 0 {
84 return num_leading_ones;
85 }
86 }
87
88 num_leading_ones
89 }
90
91 pub fn take_leading_zeros(&mut self) -> usize {
98 let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
99 self.word_len -= word_zeros;
100 self.word = self.word.wrapping_shr(word_zeros as u32);
101
102 if self.word_len != 0 {
103 return word_zeros;
104 }
105
106 let mut num_leading_zeros = word_zeros;
107
108 while self.rest_len != 0 {
109 self.word_len = usize::min(self.rest_len, 64);
110 self.rest_len -= self.word_len;
111 unsafe {
112 let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
113 self.word = u64::from_le_bytes(chunk);
114 self.bytes = self.bytes.get_unchecked(8..);
115 }
116
117 let word_zeros = usize::min(self.word_len, self.word.trailing_zeros() as usize);
118 self.word_len -= word_zeros;
119 self.word = self.word.wrapping_shr(word_zeros as u32);
120 num_leading_zeros += word_zeros;
121
122 if self.word_len != 0 {
123 return num_leading_zeros;
124 }
125 }
126
127 num_leading_zeros
128 }
129
130 #[inline]
132 pub fn num_remaining(&self) -> usize {
133 self.word_len + self.rest_len
134 }
135
136 pub fn collect_n_into(&mut self, bitmap: &mut MutableBitmap, n: usize) {
138 fn collect_word(
139 word: &mut u64,
140 word_len: &mut usize,
141 bitmap: &mut MutableBitmap,
142 n: &mut usize,
143 ) {
144 while *n > 0 && *word_len > 0 {
145 {
146 let trailing_ones = u32::min(word.trailing_ones(), *word_len as u32);
147 let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_ones);
148 *word = word.wrapping_shr(shift);
149 *word_len -= shift as usize;
150 *n -= shift as usize;
151
152 bitmap.extend_constant(shift as usize, true);
153 }
154
155 {
156 let trailing_zeros = u32::min(word.trailing_zeros(), *word_len as u32);
157 let shift = u32::min(usize::min(*n, u32::MAX as usize) as u32, trailing_zeros);
158 *word = word.wrapping_shr(shift);
159 *word_len -= shift as usize;
160 *n -= shift as usize;
161
162 bitmap.extend_constant(shift as usize, false);
163 }
164 }
165 }
166
167 let mut n = usize::min(n, self.num_remaining());
168 bitmap.reserve(n);
169
170 collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
171
172 if n == 0 {
173 return;
174 }
175
176 let num_words = n / 64;
177
178 if num_words > 0 {
179 assert!(self.bytes.len() >= num_words * size_of::<u64>());
180
181 bitmap.extend_from_slice(self.bytes, 0, num_words * u64::BITS as usize);
182
183 self.bytes = unsafe { self.bytes.get_unchecked(num_words * 8..) };
184 self.rest_len -= num_words * u64::BITS as usize;
185 n -= num_words * u64::BITS as usize;
186 }
187
188 if n == 0 {
189 return;
190 }
191
192 assert!(self.bytes.len() >= size_of::<u64>());
193
194 self.word_len = usize::min(self.rest_len, 64);
195 self.rest_len -= self.word_len;
196 unsafe {
197 let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
198 self.word = u64::from_le_bytes(chunk);
199 self.bytes = self.bytes.get_unchecked(8..);
200 }
201
202 collect_word(&mut self.word, &mut self.word_len, bitmap, &mut n);
203
204 debug_assert!(self.num_remaining() == 0 || n == 0);
205 }
206}
207
208impl Iterator for BitmapIter<'_> {
209 type Item = bool;
210
211 #[inline]
212 fn next(&mut self) -> Option<Self::Item> {
213 if self.word_len == 0 {
214 if self.rest_len == 0 {
215 return None;
216 }
217
218 self.word_len = self.rest_len.min(64);
219 self.rest_len -= self.word_len;
220
221 unsafe {
222 let chunk = self.bytes.get_unchecked(..8).try_into().unwrap();
223 self.word = u64::from_le_bytes(chunk);
224 self.bytes = self.bytes.get_unchecked(8..);
225 }
226 }
227
228 let ret = self.word & 1 != 0;
229 self.word >>= 1;
230 self.word_len -= 1;
231 Some(ret)
232 }
233
234 #[inline]
235 fn size_hint(&self) -> (usize, Option<usize>) {
236 let num_remaining = self.num_remaining();
237 (num_remaining, Some(num_remaining))
238 }
239}
240
241impl DoubleEndedIterator for BitmapIter<'_> {
242 #[inline]
243 fn next_back(&mut self) -> Option<bool> {
244 if self.rest_len > 0 {
245 self.rest_len -= 1;
246 Some(unsafe { get_bit_unchecked(self.bytes, self.rest_len) })
247 } else if self.word_len > 0 {
248 self.word_len -= 1;
249 Some(self.word & (1 << self.word_len) != 0)
250 } else {
251 None
252 }
253 }
254}
255
256unsafe impl TrustedLen for BitmapIter<'_> {}
257impl ExactSizeIterator for BitmapIter<'_> {}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Requesting more bits than remain (129 of 128) must collect exactly the
    /// remaining bits. The name suggests a regression test for issue 17579.
    #[test]
    fn test_collect_into_17579() {
        let mut bitmap = MutableBitmap::with_capacity(64);
        BitmapIter::new(&[0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0], 0, 128)
            .collect_n_into(&mut bitmap, 129);

        let bitmap = bitmap.freeze();

        assert_eq!(bitmap.set_bits(), 4);
    }

    /// Collects random bit patterns in randomly sized steps and checks that the
    /// collected bitmap has the expected number of set and unset bits.
    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_collect_into() {
        for _ in 0..10_000 {
            let mut set_bits = 0;
            let mut unset_bits = 0;

            let mut length = 0;
            let mut pattern = Vec::new();

            // Whole words: all-zero, all-one, or random.
            for _ in 0..rand::random::<usize>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                set_bits += word.count_ones();
                unset_bits += word.count_zeros();
                length += 64;
            }

            // Up to six whole trailing bytes.
            for _ in 0..rand::random::<usize>() % 7 {
                let b = rand::random::<u8>();
                pattern.push(b);
                set_bits += b.count_ones();
                unset_bits += b.count_zeros();
                length += 8;
            }

            // An optional final partial byte; only its low `last_length` bits count.
            let last_length = rand::random::<usize>() % 8;
            if last_length != 0 {
                let b = rand::random::<u8>();
                pattern.push(b);
                let ones = (b & ((1 << last_length) - 1)).count_ones();
                set_bits += ones;
                unset_bits += last_length as u32 - ones;
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length);
            let mut bitmap = MutableBitmap::with_capacity(length);

            while iter.num_remaining() > 0 {
                let len_before = bitmap.len();
                // Draw from `1..=num_remaining()`. Drawing from `0..num_remaining()`
                // can never pick the final bit, so the loop would spin forever
                // with `n == 0` once a single bit is left.
                let n = 1 + rand::random::<usize>() % iter.num_remaining();
                iter.collect_n_into(&mut bitmap, n);

                // Ensure we are booking the progress we expect.
                assert_eq!(bitmap.len(), len_before + n);
            }

            let bitmap = bitmap.freeze();

            assert_eq!(bitmap.set_bits(), set_bits as usize);
            assert_eq!(bitmap.unset_bits(), unset_bits as usize);
        }
    }

    /// Checks `take_leading_ones` / `take_leading_zeros` against the plain
    /// bit-by-bit iterator on random patterns.
    #[test]
    #[ignore = "Fuzz test. Too slow"]
    fn test_fuzz_leading_ops() {
        for _ in 0..10_000 {
            let mut length = 0;
            let mut pattern = Vec::new();

            // Whole words: all-zero, all-one, or random.
            for _ in 0..rand::random::<usize>() % 1024 {
                let bs = rand::random::<u8>() % 4;

                let word = match bs {
                    0 => u64::MIN,
                    1 => u64::MAX,
                    2 | 3 => rand::random(),
                    _ => unreachable!(),
                };

                pattern.extend_from_slice(&word.to_le_bytes());
                length += 64;
            }

            // Up to six whole trailing bytes.
            for _ in 0..rand::random::<usize>() % 7 {
                pattern.push(rand::random::<u8>());
                length += 8;
            }

            // An optional final partial byte.
            let last_length = rand::random::<usize>() % 8;
            if last_length != 0 {
                pattern.push(rand::random::<u8>());
                length += last_length;
            }

            let mut iter = BitmapIter::new(&pattern, 0, length);

            let mut prev_remaining = iter.num_remaining();
            while iter.num_remaining() != 0 {
                // The bulk operation runs on a clone; the real iterator is
                // advanced bit by bit. `take_while` also consumes the first
                // non-matching bit, which guarantees progress every round.
                let num_ones = iter.clone().take_leading_ones();
                assert_eq!(num_ones, (&mut iter).take_while(|&b| b).count());

                let num_zeros = iter.clone().take_leading_zeros();
                assert_eq!(num_zeros, (&mut iter).take_while(|&b| !b).count());

                // Ensure that we are making progress.
                assert!(iter.num_remaining() < prev_remaining);
                prev_remaining = iter.num_remaining();
            }

            assert_eq!(iter.take_leading_zeros(), 0);
            assert_eq!(iter.take_leading_ones(), 0);
        }
    }
}