1use hibitset::BitSetLike;
2
/// Iterator over the indices set in a hierarchical bitset `T`.
///
/// The iterator keeps one mask of not-yet-visited bits per layer and walks
/// from the top layer down, yielding indices when it reaches layer 0.
#[derive(Debug, Clone)]
pub struct BitIter<T> {
    /// The bitset being iterated.
    set: T,
    /// Per-layer masks of unvisited bits; index 0 is the bottom layer and
    /// index `LAYERS - 1` the top layer.
    masks: [usize; LAYERS],
    /// For each layer below the top, the index bits contributed by the layers
    /// above it, already shifted left by `BITS`.
    prefix: [u32; LAYERS - 1],
}
9
10impl<T> BitIter<T>
11where
12 T: BitSetLike,
13{
14 pub fn new(set: T) -> Self {
15 Self {
16 masks: [0, 0, 0, set.layer3()],
17 prefix: [0; 3],
18 set,
19 }
20 }
21
22 fn handle_next(&mut self, level: usize) -> State {
23 use self::State::*;
24
25 if self.masks[level] == 0 {
26 Empty
27 } else {
28 let first_bit = self.masks[level].trailing_zeros();
29 self.masks[level] &= !(1 << first_bit);
30
31 let idx = self.prefix.get(level).cloned().unwrap_or(0) | first_bit;
32
33 if level == 0 {
34 Value(idx)
35 } else {
36 self.masks[level - 1] = self.set.get_from_layer(level - 1, idx as usize);
37 self.prefix[level - 1] = idx << BITS;
38
39 Continue
40 }
41 }
42 }
43}
44
45impl<T> BitIter<T>
46where
47 T: BitSetLike + Copy,
48{
49 pub fn split(mut self) -> (Self, Option<Self>) {
50 let other = self
51 .handle_split(3)
52 .or_else(|| self.handle_split(2))
53 .or_else(|| self.handle_split(1));
54
55 (self, other)
56 }
57
58 fn handle_split(&mut self, level: usize) -> Option<Self> {
59 if self.masks[level] == 0 {
60 None
61 } else {
62 let level_prefix = self.prefix.get(level).cloned().unwrap_or(0);
63 let first_bit = self.masks[level].trailing_zeros();
64
65 bit_average(self.masks[level])
66 .map(|average_bit| {
67 let mask = (1 << average_bit) - 1;
68 let mut other = BitIter {
69 set: self.set,
70 masks: [0; LAYERS],
71 prefix: [0; LAYERS - 1],
72 };
73
74 other.masks[level] = self.masks[level] & !mask;
75 other.prefix[level - 1] = (level_prefix | average_bit as u32) << BITS;
76 other.prefix[level..].copy_from_slice(&self.prefix[level..]);
77
78 self.masks[level] &= mask;
79 self.prefix[level - 1] = (level_prefix | first_bit) << BITS;
80
81 other
82 })
83 .or_else(|| {
84 let idx = level_prefix as usize | first_bit as usize;
85
86 self.prefix[level - 1] = (idx as u32) << BITS;
87 self.masks[level] = 0;
88 self.masks[level - 1] = self.set.get_from_layer(level - 1, idx);
89
90 None
91 })
92 }
93 }
94}
95
96impl<T: BitSetLike> BitIter<T> {
97 pub fn contains(&self, i: u32) -> bool {
98 self.set.contains(i)
99 }
100}
101
/// Outcome of advancing a `BitIter` by one step at a single layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum State {
    /// The layer has no unvisited bits left.
    Empty,
    /// A bit was consumed at a higher layer and the layer below was loaded.
    Continue,
    /// The next set index (produced only at the bottom layer).
    Value(u32),
}
108
109impl<T> Iterator for BitIter<T>
110where
111 T: BitSetLike,
112{
113 type Item = u32;
114
115 fn next(&mut self) -> Option<Self::Item> {
116 use self::State::*;
117
118 'find: loop {
119 for level in 0..LAYERS {
120 match self.handle_next(level) {
121 Value(v) => return Some(v),
122 Continue => continue 'find,
123 Empty => {}
124 }
125 }
126
127 return None;
128 }
129 }
130}
131
132impl<T: BitSetLike> BitIter<T> {}
133
134pub fn bit_average(n: usize) -> Option<usize> {
135 #[cfg(target_pointer_width = "64")]
136 let average = bit_average_u64(n as u64).map(|n| n as usize);
137
138 #[cfg(target_pointer_width = "32")]
139 let average = bit_average_u32(n as u32).map(|n| n as usize);
140
141 average
142}
143
/// 32-bit worker for `bit_average`.
///
/// Returns the index of the `count / 2`-th most significant set bit of `n`
/// (where `count` is the number of set bits). `n >> index` therefore holds
/// exactly `count / 2` set bits, and the remaining set bits lie below `index`.
/// Returns `None` when `n` has fewer than two set bits.
#[cfg(any(test, target_pointer_width = "32"))]
fn bit_average_u32(n: u32) -> Option<u32> {
    // SWAR masks: 0x5555.., 0x3333.., 0x0F0F.., 0x00FF.., 0x0000FFFF.
    const PAR: [u32; 5] = [!0 / 0x3, !0 / 0x5, !0 / 0x11, !0 / 0x101, !0 / 0x10001];

    // Set-bit counts of every 2-, 4-, 8- and 16-bit field of `n`.
    let per2 = n - ((n >> 1) & PAR[0]);
    let per4 = (per2 & PAR[1]) + ((per2 >> 2) & PAR[1]);
    let per8 = (per4 + (per4 >> 4)) & PAR[2];
    let per16 = (per8 + (per8 >> 8)) & PAR[3];

    // Popcount of the upper 16 bits, and of the whole word.
    let mut upper_count = per16 >> 16;
    let total = (per16 + upper_count) & PAR[4];
    if total <= 1 {
        return None;
    }

    // Binary search from the top. `top` is the exclusive upper end of the
    // current window and `upper_count` the popcount of its upper half. If that
    // half has fewer than `wanted` bits, the wanted bit is in the lower half:
    // move the window down and discount those bits.
    let mut wanted = total / 2;
    let mut top: u32 = 32;

    // (per-field counts, field width, mask selecting one field's count)
    let levels: [(u32, u32, u32); 4] = [
        (per8, 8, 16 - 1),
        (per4, 4, 8 - 1),
        (per2, 2, 4 - 1),
        (n, 1, 2 - 1),
    ];
    for &(counts, width, field_mask) in levels.iter() {
        if upper_count < wanted {
            top -= 2 * width;
            wanted -= upper_count;
        }
        upper_count = (counts >> (top - width)) & field_mask;
    }

    if upper_count < wanted {
        top -= 1;
    }

    Some(top - 1)
}
186
/// 64-bit worker for `bit_average`.
///
/// Returns the index of the `count / 2`-th most significant set bit of `n`
/// (where `count` is the number of set bits). `n >> index` therefore holds
/// exactly `count / 2` set bits, and the remaining set bits lie below `index`.
/// Returns `None` when `n` has fewer than two set bits.
#[cfg(any(test, target_pointer_width = "64"))]
fn bit_average_u64(n: u64) -> Option<u64> {
    // SWAR masks: 0x5555.., 0x3333.., 0x0F0F.., 0x00FF.., 0x0000FFFF..,
    // 0x00000000FFFFFFFF.
    const PAR: [u64; 6] = [
        !0 / 0x3,
        !0 / 0x5,
        !0 / 0x11,
        !0 / 0x101,
        !0 / 0x10001,
        !0 / 0x100000001,
    ];

    // Set-bit counts of every 2-, 4-, 8-, 16- and 32-bit field of `n`.
    let per2 = n - ((n >> 1) & PAR[0]);
    let per4 = (per2 & PAR[1]) + ((per2 >> 2) & PAR[1]);
    let per8 = (per4 + (per4 >> 4)) & PAR[2];
    let per16 = (per8 + (per8 >> 8)) & PAR[3];
    let per32 = (per16 + (per16 >> 16)) & PAR[4];

    // Popcount of the upper 32 bits, and of the whole word.
    let mut upper_count = per32 >> 32;
    let total = (per32 + upper_count) & PAR[5];
    if total <= 1 {
        return None;
    }

    // Binary search from the top. `top` is the exclusive upper end of the
    // current window and `upper_count` the popcount of its upper half. If that
    // half has fewer than `wanted` bits, the wanted bit is in the lower half:
    // move the window down and discount those bits.
    let mut wanted = total / 2;
    let mut top: u64 = 64;

    // (per-field counts, field width, mask selecting one field's count)
    let levels: [(u64, u64, u64); 5] = [
        (per16, 16, 256 - 1),
        (per8, 8, 16 - 1),
        (per4, 4, 8 - 1),
        (per2, 2, 4 - 1),
        (n, 1, 2 - 1),
    ];
    for &(counts, width, field_mask) in levels.iter() {
        if upper_count < wanted {
            top -= 2 * width;
            wanted -= upper_count;
        }
        upper_count = (counts >> (top - width)) & field_mask;
    }

    if upper_count < wanted {
        top -= 1;
    }

    Some(top - 1)
}
238
/// Number of layers in the bitset hierarchy walked by `BitIter`; layer 0 is
/// the bottom and layer `LAYERS - 1` the top.
const LAYERS: usize = 4;

/// log2 of the number of bits in a `usize` (64-bit targets). Each layer's
/// index is shifted left by this amount to form the prefix for the layer below.
#[cfg(target_pointer_width = "64")]
pub const BITS: usize = 6;

/// log2 of the number of bits in a `usize` (32-bit targets). Each layer's
/// index is shifted left by this amount to form the prefix for the layer below.
#[cfg(target_pointer_width = "32")]
pub const BITS: usize = 5;
246
#[cfg(test)]
mod test_bit_average {
    use hibitset::{BitSet, BitSetLike};

    use super::*;

    #[test]
    fn iterator_clone() {
        let mut set = BitSet::new();

        set.add(1);
        set.add(3);

        let iter = set.iter().skip(1);
        for (a, b) in iter.clone().zip(iter) {
            assert_eq!(a, b);
        }
    }

    /// `BitIter` yields every set index in ascending order, and `split`
    /// partitions those indices between the two resulting iterators.
    #[test]
    fn bit_iter_split_covers_all() {
        let values = [1u32, 3, 70, 5000, 100_000];
        let mut set = BitSet::new();
        for &v in values.iter() {
            set.add(v);
        }

        let all: Vec<u32> = BitIter::new(&set).collect();
        assert_eq!(all, values.to_vec());

        let (low, high) = BitIter::new(&set).split();
        let mut merged: Vec<u32> = low.collect();
        if let Some(high) = high {
            merged.extend(high);
        }
        merged.sort();
        assert_eq!(merged, values.to_vec());
    }

    #[test]
    fn parity_0_bit_average_u32() {
        struct EvenParity(u32);

        impl Iterator for EvenParity {
            type Item = u32;
            fn next(&mut self) -> Option<Self::Item> {
                if self.0 == u32::max_value() {
                    return None;
                }
                self.0 += 1;
                while self.0.count_ones() & 1 != 0 {
                    if self.0 == u32::max_value() {
                        return None;
                    }
                    self.0 += 1;
                }
                Some(self.0)
            }
        }

        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u32::max_value() / steps);
            for i in EvenParity(pos).take(steps as usize) {
                let mask = (1 << bit_average_u32(i).unwrap_or(31)) - 1;
                assert_eq!((i & mask).count_ones(), (i & !mask).count_ones(), "{:x}", i);
            }
        }
    }

    #[test]
    fn parity_1_bit_average_u32() {
        struct OddParity(u32);

        impl Iterator for OddParity {
            type Item = u32;
            fn next(&mut self) -> Option<Self::Item> {
                if self.0 == u32::max_value() {
                    return None;
                }
                self.0 += 1;
                while self.0.count_ones() & 1 == 0 {
                    if self.0 == u32::max_value() {
                        return None;
                    }
                    self.0 += 1;
                }
                Some(self.0)
            }
        }

        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u32::max_value() / steps);
            for i in OddParity(pos).take(steps as usize) {
                let mask = (1 << bit_average_u32(i).unwrap_or(31)) - 1;
                let a = (i & mask).count_ones();
                let b = (i & !mask).count_ones();
                if a < b {
                    assert_eq!(a + 1, b, "{:x}", i);
                } else if b < a {
                    assert_eq!(a, b + 1, "{:x}", i);
                } else {
                    panic!("Odd parity shouldn't split in exactly half");
                }
            }
        }
    }

    #[test]
    fn empty_bit_average_u32() {
        assert_eq!(None, bit_average_u32(0));
    }

    #[test]
    fn singleton_bit_average_u32() {
        for i in 0..32 {
            assert_eq!(None, bit_average_u32(1 << i), "{:x}", i);
        }
    }

    #[test]
    fn parity_0_bit_average_u64() {
        struct EvenParity(u64);

        impl Iterator for EvenParity {
            type Item = u64;
            fn next(&mut self) -> Option<Self::Item> {
                if self.0 == u64::max_value() {
                    return None;
                }
                self.0 += 1;
                while self.0.count_ones() & 1 != 0 {
                    if self.0 == u64::max_value() {
                        return None;
                    }
                    self.0 += 1;
                }
                Some(self.0)
            }
        }

        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u64::max_value() / steps);
            for i in EvenParity(pos).take(steps as usize) {
                let mask = (1 << bit_average_u64(i).unwrap_or(63)) - 1;
                assert_eq!((i & mask).count_ones(), (i & !mask).count_ones(), "{:x}", i);
            }
        }
    }

    #[test]
    fn parity_1_bit_average_u64() {
        struct OddParity(u64);

        impl Iterator for OddParity {
            type Item = u64;
            fn next(&mut self) -> Option<Self::Item> {
                if self.0 == u64::max_value() {
                    return None;
                }
                self.0 += 1;
                while self.0.count_ones() & 1 == 0 {
                    if self.0 == u64::max_value() {
                        return None;
                    }
                    self.0 += 1;
                }
                Some(self.0)
            }
        }

        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u64::max_value() / steps);
            for i in OddParity(pos).take(steps as usize) {
                let mask = (1 << bit_average_u64(i).unwrap_or(63)) - 1;
                let a = (i & mask).count_ones();
                let b = (i & !mask).count_ones();
                if a < b {
                    assert_eq!(a + 1, b, "{:x}", i);
                } else if b < a {
                    assert_eq!(a, b + 1, "{:x}", i);
                } else {
                    panic!("Odd parity shouldn't split in exactly half");
                }
            }
        }
    }

    #[test]
    fn empty_bit_average_u64() {
        assert_eq!(None, bit_average_u64(0));
    }

    #[test]
    fn singleton_bit_average_u64() {
        for i in 0..64 {
            assert_eq!(None, bit_average_u64(1 << i), "{:x}", i);
        }
    }

    #[test]
    fn bit_average_agree_u32_u64() {
        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u32::max_value() / steps);
            // Compare a window of `steps` consecutive values starting at `pos`.
            // (`pos..steps` would be empty for every `i > 0`.) The largest
            // `pos + steps` stays below `u32::MAX`, so this cannot overflow.
            for i in pos..pos + steps {
                assert_eq!(
                    bit_average_u32(i),
                    bit_average_u64(i as u64).map(|n| n as u32),
                    "{:x}",
                    i
                );
            }
        }
    }

    #[test]
    fn specific_values() {
        assert_eq!(Some(4), bit_average_u32(0b10110));
        assert_eq!(Some(5), bit_average_u32(0b100010));
        assert_eq!(None, bit_average_u32(0));
        assert_eq!(None, bit_average_u32(1));

        assert_eq!(Some(4), bit_average_u64(0b10110));
        assert_eq!(Some(5), bit_average_u64(0b100010));
        assert_eq!(None, bit_average_u64(0));
        assert_eq!(None, bit_average_u64(1));
    }
}
458}