polars_arrow/bitmap/bitmask.rs

#[cfg(feature = "simd")]
use std::simd::{LaneCount, Mask, MaskElement, SupportedLaneCount};

use polars_utils::slice::load_padded_le_u64;

use super::iterator::FastU56BitmapIter;
use super::utils::{BitmapIter, count_zeros, fmt};
use crate::bitmap::Bitmap;
/// Returns the position of the nth set bit in `w`, or `None` if fewer than
/// `n + 1` bits are set. Both the bit position and `n` are zero-indexed.
#[inline]
pub fn nth_set_bit_u32(w: u32, n: u32) -> Option<u32> {
    // If BMI2 is available we use PDEP, which deposits the low-order bits of
    // its first operand into the positions where its second operand has ones.
    // Depositing `1 << n` into `w` lands a single one exactly on the nth set
    // bit of `w` (or yields 0 if there is no such bit).
    #[cfg(all(not(miri), target_feature = "bmi2"))]
    {
        if n >= 32 {
            return None;
        }

        let nth_set_bit = unsafe { core::arch::x86_64::_pdep_u32(1 << n, w) };
        if nth_set_bit == 0 {
            return None;
        }

        Some(nth_set_bit.trailing_zeros())
    }

    #[cfg(any(miri, not(target_feature = "bmi2")))]
    {
        // Portable fallback: compute the number of set bits per block of
        // 2, 4, 8, 16 and finally all 32 bits.
        let set_per_2 = w - ((w >> 1) & 0x55555555);
        let set_per_4 = (set_per_2 & 0x33333333) + ((set_per_2 >> 2) & 0x33333333);
        let set_per_8 = (set_per_4 + (set_per_4 >> 4)) & 0x0f0f0f0f;
        let set_per_16 = (set_per_8 + (set_per_8 >> 8)) & 0x00ff00ff;
        let set_per_32 = (set_per_16 + (set_per_16 >> 16)) & 0xff;
        if n >= set_per_32 {
            return None;
        }

        // Binary-search downwards: whenever the lower half of the current
        // block contains fewer than n + 1 set bits, skip it and reduce n.
        let mut idx = 0;
        let mut n = n;
        let next16 = set_per_16 & 0xff;
        if n >= next16 {
            n -= next16;
            idx += 16;
        }
        let next8 = (set_per_8 >> idx) & 0xff;
        if n >= next8 {
            n -= next8;
            idx += 8;
        }
        let next4 = (set_per_4 >> idx) & 0b1111;
        if n >= next4 {
            n -= next4;
            idx += 4;
        }
        let next2 = (set_per_2 >> idx) & 0b11;
        if n >= next2 {
            n -= next2;
            idx += 2;
        }
        let next1 = (w >> idx) & 0b1;
        if n >= next1 {
            idx += 1;
        }
        Some(idx)
    }
}
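
// Worked example of the semantics above (an illustrative sketch; the test
// module at the bottom of this file checks these properties exhaustively):
//
//     let w = 0b0110_0100u32; // set bits at positions 2, 5 and 6
//     assert_eq!(nth_set_bit_u32(w, 0), Some(2));
//     assert_eq!(nth_set_bit_u32(w, 1), Some(5));
//     assert_eq!(nth_set_bit_u32(w, 2), Some(6));
//     assert_eq!(nth_set_bit_u32(w, 3), None);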

/// A read-only view of `len` bits starting at bit `offset` within `bytes`.
#[derive(Default, Clone)]
pub struct BitMask<'a> {
    bytes: &'a [u8],
    offset: usize,
    len: usize,
}

impl std::fmt::Debug for BitMask<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { bytes, offset, len } = self;
        let offset_num_bytes = offset / 8;
        let offset_in_byte = offset % 8;
        fmt(&bytes[offset_num_bytes..], offset_in_byte, *len, f)
    }
}

impl<'a> BitMask<'a> {
    pub fn from_bitmap(bitmap: &'a Bitmap) -> Self {
        let (bytes, offset, len) = bitmap.as_slice();
        Self::new(bytes, offset, len)
    }

    pub fn inner(&self) -> (&[u8], usize, usize) {
        (self.bytes, self.offset, self.len)
    }

    pub fn new(bytes: &'a [u8], offset: usize, len: usize) -> Self {
        // Verify the bit range fits in the byte slice so the unchecked byte
        // accesses below are guaranteed to stay in-bounds.
        assert!(bytes.len() * 8 >= len + offset);
        Self { bytes, offset, len }
    }

    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    #[inline]
    pub fn advance_by(&mut self, idx: usize) {
        assert!(idx <= self.len);
        self.offset += idx;
        self.len -= idx;
    }

    #[inline]
    pub fn split_at(&self, idx: usize) -> (Self, Self) {
        assert!(idx <= self.len);
        unsafe { self.split_at_unchecked(idx) }
    }

    /// # Safety
    /// `idx` must be in-bounds, that is, `idx <= self.len`.
    #[inline]
    pub unsafe fn split_at_unchecked(&self, idx: usize) -> (Self, Self) {
        debug_assert!(idx <= self.len);
        let left = Self { len: idx, ..*self };
        let right = Self {
            len: self.len - idx,
            offset: self.offset + idx,
            ..*self
        };
        (left, right)
    }

    #[inline]
    pub fn sliced(&self, offset: usize, length: usize) -> Self {
        assert!(offset.checked_add(length).unwrap() <= self.len);
        unsafe { self.sliced_unchecked(offset, length) }
    }

    /// # Safety
    /// The slice must be in-bounds, that is, `offset + length <= self.len`.
    #[inline]
    pub unsafe fn sliced_unchecked(&self, offset: usize, length: usize) -> Self {
        if cfg!(debug_assertions) {
            assert!(offset.checked_add(length).unwrap() <= self.len);
        }

        Self {
            bytes: self.bytes,
            offset: self.offset + offset,
            len: length,
        }
    }

    pub fn unset_bits(&self) -> usize {
        count_zeros(self.bytes, self.offset, self.len)
    }

    pub fn set_bits(&self) -> usize {
        self.len - self.unset_bits()
    }

    pub fn fast_iter_u56(&self) -> FastU56BitmapIter<'_> {
        FastU56BitmapIter::new(self.bytes, self.offset, self.len)
    }

    #[cfg(feature = "simd")]
    #[inline]
    pub fn get_simd<T, const N: usize>(&self, idx: usize) -> Mask<T, N>
    where
        T: MaskElement,
        LaneCount<N>: SupportedLaneCount,
    {
        let lanes = LaneCount::<N>::BITMASK_LEN;

        // We don't support 64-lane masks because then we couldn't load the
        // bits as a u64 and still do the byte-shift on it.
        assert!(lanes < 64);

        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + lanes <= self.len {
            // Fast path: the entire load is in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            Mask::from_bitmask(mask >> byte_shift)
        } else if idx < self.len {
            // Partially out-of-bounds: zero out the lanes that fall past the
            // end of the mask.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let num_out_of_bounds = idx + lanes - self.len;
            let shifted = (mask << num_out_of_bounds) >> (num_out_of_bounds + byte_shift);
            Mask::from_bitmask(shifted)
        } else {
            Mask::from_bitmask(0u64)
        }
    }

    #[inline]
    pub fn get_u32(&self, idx: usize) -> u32 {
        let start_byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;
        if idx + 32 <= self.len {
            // Fast path: the entire 32-bit load is in-bounds.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            (mask >> byte_shift) as u32
        } else if idx < self.len {
            // Partially out-of-bounds: zero out the bits that fall past the
            // end of the mask.
            let mask = load_padded_le_u64(unsafe { self.bytes.get_unchecked(start_byte_idx..) });
            let out_of_bounds_mask = (1u32 << (self.len - idx)) - 1;
            ((mask >> byte_shift) as u32) & out_of_bounds_mask
        } else {
            0
        }
    }
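
    // Example for get_u32 above: if the mask bits starting at idx are
    // 1, 0, 1, 1, 0, 0, ... then get_u32(idx) == 0b001101, i.e. bit i of the
    // result is bit idx + i of the mask, and bits past self.len read as 0.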

    /// Computes the index of the nth set bit at or after `start`, if any.
    /// Both `n` and the returned index are zero-indexed, and the returned
    /// index is absolute, not relative to `start`.
    pub fn nth_set_bit_idx(&self, mut n: usize, mut start: usize) -> Option<usize> {
        while start < self.len {
            let next_u32_mask = self.get_u32(start);
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense (non)null sections.
                if n < 32 {
                    return Some(start + n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    // SAFETY: we just checked the nth set bit is in the mask.
                    let idx = unsafe {
                        nth_set_bit_u32(next_u32_mask, n as u32).unwrap_unchecked() as usize
                    };
                    return Some(start + idx);
                }
                n -= ones;
            }

            start += 32;
        }

        None
    }
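
    // Example for nth_set_bit_idx above: if the first bits of the mask are
    // 0, 1, 1, 0, 0, 1 (indices 1, 2 and 5 set), then
    // nth_set_bit_idx(0, 0) == Some(1), nth_set_bit_idx(1, 0) == Some(2),
    // nth_set_bit_idx(0, 2) == Some(2), and, assuming no other bits are set,
    // nth_set_bit_idx(2, 3) == None.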

    /// Computes the index of the nth set bit before `end`, counting backwards
    /// from `end`, if any. Both `n` and the returned index are zero-indexed,
    /// and the returned index is absolute, not relative to `end`.
    pub fn nth_set_bit_idx_rev(&self, mut n: usize, mut end: usize) -> Option<usize> {
        while end > 0 {
            // Take the u32 chunk just before `end`, masking out any bits at
            // or beyond `end` when fewer than 32 bits remain.
            let (u32_mask_start, u32_mask_mask) = if end >= 32 {
                (end - 32, u32::MAX)
            } else {
                (0, (1 << end) - 1)
            };
            let next_u32_mask = self.get_u32(u32_mask_start) & u32_mask_mask;
            if next_u32_mask == u32::MAX {
                // Happy fast path for dense (non)null sections.
                if n < 32 {
                    return Some(end - 1 - n);
                }
                n -= 32;
            } else {
                let ones = next_u32_mask.count_ones() as usize;
                if n < ones {
                    let rev_n = ones - 1 - n;
                    // SAFETY: we just checked the rev_n'th set bit is in the mask.
                    let idx = unsafe {
                        nth_set_bit_u32(next_u32_mask, rev_n as u32).unwrap_unchecked() as usize
                    };
                    return Some(u32_mask_start + idx);
                }
                n -= ones;
            }

            end = u32_mask_start;
        }

        None
    }

    #[inline]
    pub fn get(&self, idx: usize) -> bool {
        let byte_idx = (self.offset + idx) / 8;
        let byte_shift = (self.offset + idx) % 8;

        if idx < self.len {
            // SAFETY: the bit range was checked against the byte slice in new().
            let byte = unsafe { *self.bytes.get_unchecked(byte_idx) };
            (byte >> byte_shift) & 1 == 1
        } else {
            false
        }
    }

    pub fn iter(&self) -> BitmapIter<'_> {
        BitmapIter::new(self.bytes, self.offset, self.len)
    }
}
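
// A minimal usage sketch (illustrative only; `bitmap` stands for any
// `Bitmap` value obtained elsewhere):
//
//     let mask = BitMask::from_bitmap(&bitmap);
//     let ones = mask.set_bits();
//     let first_set = mask.nth_set_bit_idx(0, 0);
//     let last_set = mask.nth_set_bit_idx_rev(0, mask.len());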

#[cfg(test)]
mod test {
    use super::*;

    // Reference implementation: scan the bits from the LSB upwards.
    fn naive_nth_bit_set(mut w: u32, mut n: u32) -> Option<u32> {
        for i in 0..32 {
            if w & (1 << i) != 0 {
                if n == 0 {
                    return Some(i);
                }
                n -= 1;
                w ^= 1 << i;
            }
        }
        None
    }

    #[test]
    fn test_nth_set_bit_u32() {
        for n in 0..256 {
            assert_eq!(nth_set_bit_u32(0, n), None);
        }

        for i in 0..32 {
            assert_eq!(nth_set_bit_u32(1 << i, 0), Some(i));
            assert_eq!(nth_set_bit_u32(1 << i, 1), None);
        }

        for i in 0..10000 {
            let rnd = (0xbdbc9d8ec9d5c461u64.wrapping_mul(i as u64) >> 32) as u32;
            for i in 0..=32 {
                assert_eq!(nth_set_bit_u32(rnd, i), naive_nth_bit_set(rnd, i));
            }
        }
    }
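
    // Additional sketch of a property check for BitMask itself (not from the
    // original file): compare nth_set_bit_idx / nth_set_bit_idx_rev against a
    // naive linear scan, using only the public API defined above.
    #[test]
    fn test_nth_set_bit_idx_against_naive_scan() {
        let bytes = [0b1011_0100u8, 0b0110_0001, 0xff, 0x00, 0b1000_0000];
        // 30 bits starting at bit offset 3.
        let mask = BitMask::new(&bytes, 3, 30);
        for n in 0..mask.len() {
            // Forward: the nth set bit overall, scanning from index 0.
            let naive_fwd = (0..mask.len()).filter(|&i| mask.get(i)).nth(n);
            assert_eq!(mask.nth_set_bit_idx(n, 0), naive_fwd);

            // Backward: the nth set bit counting down from the end.
            let naive_rev = (0..mask.len()).rev().filter(|&i| mask.get(i)).nth(n);
            assert_eq!(mask.nth_set_bit_idx_rev(n, mask.len()), naive_rev);
        }
    }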
}