#[cfg(feature = "simd")]
use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};

use polars_utils::slice::load_padded_le_u64;

use super::iterator::FastU56BitmapIter;
use super::utils::{self, BitChunk, BitChunks, BitmapIter, count_zeros, fmt};
use crate::bitmap::Bitmap;

/// Returns the index of the nth set bit in `w`, counted from the least
/// significant bit, if there is one.
#[inline]
pub fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {
    // If BMI2's PDEP is available (x86_64 only, since the intrinsic lives in
    // core::arch::x86_64), use it: depositing 1 << n into the set bits of w
    // leaves a single bit at the position of the nth set bit of w, or zero
    // if fewer than n + 1 bits are set.
    #[cfg(all(not(miri), target_arch = "x86_64", target_feature = "bmi2"))]
    {
        if n >= 32 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros())
    }

    #[cfg(not(all(not(miri), target_arch = "x86_64", target_feature = "bmi2")))]
    {
        // Otherwise compute a SWAR popcount tree: set_per_k holds the number
        // of set bits within each aligned k-bit block of w.
        let set_per_2 = w - ((w >> 1) & 0x55555555);
        let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;

        if n >= set_per_32 {
            return None;
        }

        let mut idx = 0;
        let mut n = n;

        // Walk down the tree: at each level, if the lower block does not
        // contain enough set bits, skip over it and search the upper block.
        let next16 = set_per_16 & 0xff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}
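
// A worked example of the two strategies above (values chosen purely for
// illustration): take w = 0b1101_0010 (set bits at 1, 4, 6 and 7) and n = 2.
// PDEP deposits 1 << 2 into the set-bit positions of w, producing
// 0b0100_0000, whose trailing_zeros() is 6, the index of the third set bit.
// The fallback reaches the same answer via the popcount tree: the low nibble
// holds one set bit, so it skips to idx = 4 with n = 1; bits 4..6 hold one
// more, so it skips to idx = 6 with n = 0, where bit 6 is set.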

/// Returns the index of the nth set bit in `w`, counted from the least
/// significant bit, if there is one.
#[inline]
pub fn nth_set_bit_u64(w: u64, n: u64) -> Option<u64> {
    // See nth_set_bit_u32 for an explanation of both strategies.
    #[cfg(all(not(miri), target_arch = "x86_64", target_feature = "bmi2"))]
    {
        if n >= 64 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u64(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros().into())
    }

    #[cfg(not(all(not(miri), target_arch = "x86_64", target_feature = "bmi2")))]
    {
        let set_per_2 = w - ((w >> 1) & 0x5555555555555555);
        let set_per_4 = (set_per_2 & 0x3333333333333333) + ((set_per_2 >> 2) & 0x3333333333333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0x0000ffff0000ffff;
        let set_per_64 = (set_per_32 + (set_per_32 >> 32)) & 0xffffffff;

        if n >= set_per_64 {
            return None;
        }

        let mut idx = 0;
        let mut n = n;

        let next32 = set_per_32 & 0xffff;
        if n >= next32 {
            n -= next32;
            idx += 32;
        }
        let next16 = (set_per_16 >> idx) & 0xffff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}

/// A borrowed view of `len` bits of `bytes`, starting `offset` bits in.
#[derive(Default, Clone, Copy)]
pub struct BitMask<'a> {
    bytes: &'a [u8],
    offset: usize,
    len: usize,
}

impl std::fmt::Debug for BitMask<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { bytes, offset, len } = self;
        let offset_num_bytes = offset / 8;
        let offset_in_byte = offset % 8;
        fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)
    }
}

impl<'a> BitMask<'a> {
    pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
        let (bytes, offset, len) = bitmap.as_slice();
        Self::new(bytes, offset, len)
    }

    pub fn inner(&self) -> (&[u8], usize, usize) {
        (self.bytes, self.offset, self.len)
    }

    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
        // This check ensures the view fits in the buffer, so the unchecked
        // byte accesses below stay in-bounds.
        assert!(bytes.len() * 8 >= len + offset);
        Self { bytes, offset, len }
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn advance_by(&mut self, idx: usize) {
        assert!(idx <= self.len);
        self.offset += idx;
        self.len -= idx;
    }

    #[inline]
    pub fn split_at(&self, idx: usize) -> (Self, Self) {
        assert!(idx <= self.len);
        unsafe { self.split_at_unchecked(idx) }
    }

    /// # Safety
    /// `idx` must not exceed `self.len`.
    #[inline]
    pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
        debug_assert!(idx <= self.len);
        let left = Self { len: idx, ..*self };
        let right = Self {
            len: self.len - idx,
            offset: self.offset + idx,
            ..*self
        };
        (left, right)
    }

    #[inline]
    pub fn sliced(&self, offset: usize, length: usize) -> Self {
        assert!(offset.checked_add(length).unwrap() <= self.len);
        unsafe { self.sliced_unchecked(offset, length) }
    }

    /// # Safety
    /// `offset + length` must not exceed `self.len`.
    #[inline]
    pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {
        if cfg!(debug_assertions) {
            assert!(offset.checked_add(length).unwrap() <= self.len);
        }

        Self {
            bytes: self.bytes,
            offset: self.offset + offset,
            len: length,
        }
    }

    pub fn unset_bits(&self) -> usize {
        count_zeros(self.bytes, self.offset, self.len)
    }

    pub fn set_bits(&self) -> usize {
        self.len - self.unset_bits()
    }

    pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> {
        FastU56BitmapIter::new(self.bytes, self.offset, self.len)
    }

    /// Loads `N` bits starting at `idx` as a SIMD mask. Bits past the end of
    /// the view read as false.
    #[cfg(feature = "simd")]
    #[inline]
    pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>
    where
        T: MaskElement,
        LaneCount<N>: SupportedLaneCount,
    {
        // We don't support 64-lane masks because then we couldn't load our
        // bitwise mask as a u64 and then do the byte shift.
        let lanes = LaneCount::<N>::BITMASK_LEN;
        assert!(lanes < 64);

        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + lanes <= self.len {
            // SAFETY: fast path, the entire load is in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            Mask::from_bitmask(mask >> byte_shift)
        } else if idx < self.len {
            // SAFETY: idx < self.len, so the first byte is in-bounds, and
            // load_padded_le_u64 zero-pads the (possibly short) tail.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            // Shift the out-of-bounds bits off the top so they read as false.
            let num_out_of_bounds = idx + lanes - self.len;
            let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
            Mask::from_bitmask(shifted)
        } else {
            Mask::from_bitmask(0u64)
        }
    }

    /// Loads 32 bits starting at `idx`, LSB-first. Bits past the end of the
    /// view read as zero.
    #[inline]
    pub fn get_u32(&self, idx: usize) -> u32 {
        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + 32 <= self.len {
            // SAFETY: fast path, the entire load is in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            (mask >> byte_shift) as u32
        } else if idx < self.len {
            // SAFETY: idx < self.len, so the first byte is in-bounds, and
            // load_padded_le_u64 zero-pads the (possibly short) tail.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            // Mask out the bits that are out of bounds.
            let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;
            ((mask >> byte_shift) as u32) & out_of_bounds_mask
        } else {
            0
        }
    }

    /// Computes the index of the nth set bit at or after `start`.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx(0, 0)` returns the index of
    /// the first set bit. The returned index is absolute, not relative to
    /// `start`.
    pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {
        while start < self.len {
            let next_u32_mask = self.get_u32(start);
            if next_u32_mask == u32::MAX {
                // Fast path for dense masks.
                if n < 32 {
                    return Some(start + n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    // SAFETY: we just checked next_u32_mask has more than n
                    // set bits.
                    let idx = unsafe {
                        nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize
                    };
                    return Some(start + idx);
                }
                n -= ones;
            }

            start += 32;
        }

        None
    }

    /// Computes the index of the nth set bit before `end`, counting backwards
    /// from `end`.
    ///
    /// Both are zero-indexed, so `nth_set_bit_idx_rev(0, len)` returns the
    /// index of the last set bit. The returned index is absolute, counted
    /// from the start of the view, not relative to `end`.
    pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
        while end > 0 {
            // Each iteration examines the (at most) 32 bits just below end.
            let (u32_mask_start, u32_mask_mask) = if end >= 32 {
                (end - 32, u32::MAX)
            } else {
                (0, (1 << end) - 1)
            };
            let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;
            if next_u32_mask == u32::MAX {
                // Fast path for dense masks.
                if n < 32 {
                    return Some(end - 1 - n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    // Counting n backwards from the top is the same as
                    // counting ones - 1 - n forwards from the bottom.
                    let rev_n = ones - 1 - n;
                    // SAFETY: rev_n < ones, so the bit exists.
                    let idx = unsafe {
                        nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize
                    };
                    return Some(u32_mask_start + idx);
                }
                n -= ones;
            }

            end = u32_mask_start;
        }

        None
    }

    /// Returns the bit at `idx`; out-of-bounds reads return false.
    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        if idx < self.len {
            // SAFETY: we just did the bounds check.
            unsafe { self.get_bit_unchecked(idx) }
        } else {
            false
        }
    }

    /// # Safety
    /// `idx` must be in-bounds, that is `idx < self.len`.
    #[inline]
    pub unsafe fn get_bit_unchecked(&self, idx: usize) -> bool {
        let byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;

        // SAFETY: idx < self.len, and the constructor checked that
        // bytes.len() * 8 >= offset + len, so byte_idx is in-bounds.
        let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
        (byte >> byte_shift) & 1 == 1
    }

    pub fn iter(self) -> BitmapIter<'a> {
        BitmapIter::new(self.bytes, self.offset, self.len)
    }

    pub fn leading_zeros(self) -> usize {
        utils::leading_zeros(self.bytes, self.offset, self.len)
    }
    pub fn leading_ones(self) -> usize {
        utils::leading_ones(self.bytes, self.offset, self.len)
    }
    pub fn trailing_zeros(self) -> usize {
        utils::trailing_zeros(self.bytes, self.offset, self.len)
    }
    pub fn trailing_ones(self) -> usize {
        utils::trailing_ones(self.bytes, self.offset, self.len)
    }

    /// Returns whether `self` and `other` have at least one common set bit.
    pub fn intersects_with(self, other: Self) -> bool {
        self.num_intersections_with(other) != 0
    }

    /// Returns the number of bits set in both `self` and `other`.
    pub fn num_intersections_with(self, other: Self) -> usize {
        super::num_intersections_with(self, other)
    }

    /// Iterates over the mask in `T`-sized chunks of bits.
    pub fn chunks<T: BitChunk>(self) -> BitChunks<'a, T> {
        BitChunks::new(self.bytes, self.offset, self.len)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    // Naive reference implementations: scan the bits one by one.
    fn naive_nth_bit_set_u32(mut w: u32, mut n: u32) -> Option<u32> {
        for i in 0..32 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    fn naive_nth_bit_set_u64(mut w: u64, mut n: u64) -> Option<u64> {
        for i in 0..64 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }
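
    // Not part of the original tests: a sanity check (bytes and offsets
    // chosen arbitrarily) that BitMask::new + get index bits LSB-first
    // within each byte, including when the view starts mid-byte.
    #[test]
    fn test_get_with_offset() {
        let bytes = [0b1011_0100u8, 0b0000_0110];
        let expected = [
            false, false, true, false, true, true, false, true, // 0b1011_0100
            false, true, true, false, false, false, false, false, // 0b0000_0110
        ];

        let mask = BitMask::new(&bytes, 0, 16);
        for (i, &e) in expected.iter().enumerate() {
            assert_eq!(mask.get(i), e);
        }
        // Out-of-bounds reads return false.
        assert!(!mask.get(16));

        // A view starting mid-byte sees the same bits, shifted.
        let mask = BitMask::new(&bytes, 3, 13);
        for i in 0..13 {
            assert_eq!(mask.get(i), expected[i + 3]);
        }
    }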

    #[test]
    fn test_nth_set_bit_u32() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u32(0, n), None);
        }

        for i in 0..32 {
            assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u32(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;
            for n in 0..=32 {
                assert_eq!(nth_set_bit_u32(rnd, n), naive_nth_bit_set_u32(rnd, n));
            }
        }
    }
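
    // Not part of the original tests: checks that split_at and sliced agree
    // with get, and that set_bits/unset_bits match a naive count over iter.
    #[test]
    fn test_split_slice_counts() {
        let bytes = [0b1100_1010u8, 0b0101_0011, 0b1111_0000];
        let mask = BitMask::new(&bytes, 2, 20);

        for split in 0..=mask.len() {
            let (left, right) = mask.split_at(split);
            assert_eq!(left.len() + right.len(), mask.len());
            for i in 0..left.len() {
                assert_eq!(left.get(i), mask.get(i));
            }
            for i in 0..right.len() {
                assert_eq!(right.get(i), mask.get(split + i));
            }
        }

        let sliced = mask.sliced(5, 9);
        for i in 0..9 {
            assert_eq!(sliced.get(i), mask.get(5 + i));
        }

        let ones = mask.iter().filter(|&b| b).count();
        assert_eq!(mask.set_bits(), ones);
        assert_eq!(mask.unset_bits(), mask.len() - ones);
    }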

    #[test]
    fn test_nth_set_bit_u64() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u64(0, n), None);
        }

        for i in 0..64 {
            assert_eq!(nth_set_bit_u64(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u64(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = 0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32;
            for n in 0..=64 {
                assert_eq!(nth_set_bit_u64(rnd, n), naive_nth_bit_set_u64(rnd, n));
            }
        }
    }
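
    // Not part of the original tests: cross-checks get_u32, nth_set_bit_idx
    // and nth_set_bit_idx_rev against naive scans over get, using an
    // arbitrarily chosen pseudo-random bit pattern.
    #[test]
    fn test_nth_set_bit_idx() {
        let bytes: Vec<u8> = (0..16u64)
            .map(|i| (0xbdbc9d8ec9d5c461u64.wrapping_mul(i + 1) >> 56) as u8)
            .collect();
        let mask = BitMask::new(&bytes, 3, 120);

        // get_u32 must agree with get, bit by bit (out-of-bounds bits are 0).
        for i in 0..mask.len() {
            let w = mask.get_u32(i);
            for j in 0..32 {
                assert_eq!((w >> j) & 1 == 1, mask.get(i + j));
            }
        }

        let set_idxs: Vec<usize> = (0..mask.len()).filter(|&i| mask.get(i)).collect();
        for n in 0..set_idxs.len() + 2 {
            assert_eq!(mask.nth_set_bit_idx(n, 0), set_idxs.get(n).copied());
            assert_eq!(
                mask.nth_set_bit_idx_rev(n, mask.len()),
                set_idxs.iter().rev().nth(n).copied()
            );
        }
    }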
}