async_ecs/misc/bit/
iter.rs

1use hibitset::BitSetLike;
2
3#[derive(Debug, Clone)]
4pub struct BitIter<T> {
5    set: T,
6    masks: [usize; LAYERS],
7    prefix: [u32; LAYERS - 1],
8}
9
10impl<T> BitIter<T>
11where
12    T: BitSetLike,
13{
14    pub fn new(set: T) -> Self {
15        Self {
16            masks: [0, 0, 0, set.layer3()],
17            prefix: [0; 3],
18            set,
19        }
20    }
21
22    fn handle_next(&mut self, level: usize) -> State {
23        use self::State::*;
24
25        if self.masks[level] == 0 {
26            Empty
27        } else {
28            let first_bit = self.masks[level].trailing_zeros();
29            self.masks[level] &= !(1 << first_bit);
30
31            let idx = self.prefix.get(level).cloned().unwrap_or(0) | first_bit;
32
33            if level == 0 {
34                Value(idx)
35            } else {
36                self.masks[level - 1] = self.set.get_from_layer(level - 1, idx as usize);
37                self.prefix[level - 1] = idx << BITS;
38
39                Continue
40            }
41        }
42    }
43}
44
45impl<T> BitIter<T>
46where
47    T: BitSetLike + Copy,
48{
49    pub fn split(mut self) -> (Self, Option<Self>) {
50        let other = self
51            .handle_split(3)
52            .or_else(|| self.handle_split(2))
53            .or_else(|| self.handle_split(1));
54
55        (self, other)
56    }
57
58    fn handle_split(&mut self, level: usize) -> Option<Self> {
59        if self.masks[level] == 0 {
60            None
61        } else {
62            let level_prefix = self.prefix.get(level).cloned().unwrap_or(0);
63            let first_bit = self.masks[level].trailing_zeros();
64
65            bit_average(self.masks[level])
66                .map(|average_bit| {
67                    let mask = (1 << average_bit) - 1;
68                    let mut other = BitIter {
69                        set: self.set,
70                        masks: [0; LAYERS],
71                        prefix: [0; LAYERS - 1],
72                    };
73
74                    other.masks[level] = self.masks[level] & !mask;
75                    other.prefix[level - 1] = (level_prefix | average_bit as u32) << BITS;
76                    other.prefix[level..].copy_from_slice(&self.prefix[level..]);
77
78                    self.masks[level] &= mask;
79                    self.prefix[level - 1] = (level_prefix | first_bit) << BITS;
80
81                    other
82                })
83                .or_else(|| {
84                    let idx = level_prefix as usize | first_bit as usize;
85
86                    self.prefix[level - 1] = (idx as u32) << BITS;
87                    self.masks[level] = 0;
88                    self.masks[level - 1] = self.set.get_from_layer(level - 1, idx);
89
90                    None
91                })
92        }
93    }
94}
95
96impl<T: BitSetLike> BitIter<T> {
97    pub fn contains(&self, i: u32) -> bool {
98        self.set.contains(i)
99    }
100}
101
/// Outcome of advancing a single layer of a `BitIter`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum State {
    /// The layer has no unvisited bits left.
    Empty,
    /// A bit was consumed on a higher layer and the layer below it was
    /// (re)loaded; the iterator restarts its scan from the lowest layer.
    Continue,
    /// A bit was consumed on the lowest layer; carries the resulting index.
    Value(u32),
}
108
109impl<T> Iterator for BitIter<T>
110where
111    T: BitSetLike,
112{
113    type Item = u32;
114
115    fn next(&mut self) -> Option<Self::Item> {
116        use self::State::*;
117
118        'find: loop {
119            for level in 0..LAYERS {
120                match self.handle_next(level) {
121                    Value(v) => return Some(v),
122                    Continue => continue 'find,
123                    Empty => {}
124                }
125            }
126
127            return None;
128        }
129    }
130}
131
132impl<T: BitSetLike> BitIter<T> {}
133
134pub fn bit_average(n: usize) -> Option<usize> {
135    #[cfg(target_pointer_width = "64")]
136    let average = bit_average_u64(n as u64).map(|n| n as usize);
137
138    #[cfg(target_pointer_width = "32")]
139    let average = bit_average_u32(n as u32).map(|n| n as usize);
140
141    average
142}
143
/// 32-bit version of `bit_average`: returns the position of the
/// `n.count_ones() / 2`-th most significant set bit, or `None` if `n` has
/// fewer than two bits set.
#[cfg(any(test, target_pointer_width = "32"))]
fn bit_average_u32(n: u32) -> Option<u32> {
    // MASK[k] keeps the low half of every group of 2^(k+1) bits:
    // MASK[k] == !0 / ((1 << (1 << k)) | 1).
    const MASK: [u32; 5] = [!0 / 0x3, !0 / 0x5, !0 / 0x11, !0 / 0x101, !0 / 0x10001];

    // SWAR population counts: `perK` holds, in each K-bit field, the number
    // of set bits of `n` inside that field.
    let per2 = n - ((n >> 1) & MASK[0]);
    let per4 = (per2 & MASK[1]) + ((per2 >> 2) & MASK[1]);
    let per8 = (per4 + (per4 >> 4)) & MASK[2];
    let per16 = (per8 + (per8 >> 8)) & MASK[3];

    // The search starts with the count of the upper 16 bits.
    let mut upper = per16 >> 16;
    let total = (per16 + upper) & MASK[4];

    if total < 2 {
        return None;
    }

    // How many set bits still have to lie above the answer.
    let mut wanted = total / 2;
    // Exclusive upper bound of the window being searched.
    let mut top = 32;

    // Binary search from the top. Each entry is (counts one level finer,
    // half the width of the next window, mask for one count field).
    let steps = [(per8, 8, 0xF), (per4, 4, 0x7), (per2, 2, 0x3), (n, 1, 0x1)];
    for &(counts, half, field) in steps.iter() {
        if upper < wanted {
            // Too few bits in the upper half: continue in the lower half.
            top -= 2 * half;
            wanted -= upper;
        }
        upper = (counts >> (top - half)) & field;
    }

    if upper < wanted {
        top -= 1;
    }

    Some(top - 1)
}
186
/// 64-bit version of `bit_average`: returns the position of the
/// `n.count_ones() / 2`-th most significant set bit, or `None` if `n` has
/// fewer than two bits set.
#[cfg(any(test, target_pointer_width = "64"))]
fn bit_average_u64(n: u64) -> Option<u64> {
    // MASK[k] keeps the low half of every group of 2^(k+1) bits:
    // MASK[k] == !0 / ((1 << (1 << k)) | 1).
    const MASK: [u64; 6] = [
        !0 / 0x3,
        !0 / 0x5,
        !0 / 0x11,
        !0 / 0x101,
        !0 / 0x10001,
        !0 / 0x100000001,
    ];

    // SWAR population counts: `perK` holds, in each K-bit field, the number
    // of set bits of `n` inside that field.
    let per2 = n - ((n >> 1) & MASK[0]);
    let per4 = (per2 & MASK[1]) + ((per2 >> 2) & MASK[1]);
    let per8 = (per4 + (per4 >> 4)) & MASK[2];
    let per16 = (per8 + (per8 >> 8)) & MASK[3];
    let per32 = (per16 + (per16 >> 16)) & MASK[4];

    // The search starts with the count of the upper 32 bits.
    let mut upper = per32 >> 32;
    let total = (per32 + upper) & MASK[5];

    if total < 2 {
        return None;
    }

    // How many set bits still have to lie above the answer.
    let mut wanted = total / 2;
    // Exclusive upper bound of the window being searched.
    let mut top = 64;

    // Binary search from the top. Each entry is (counts one level finer,
    // half the width of the next window, mask for one count field).
    let steps = [
        (per16, 16, 0xFF),
        (per8, 8, 0xF),
        (per4, 4, 0x7),
        (per2, 2, 0x3),
        (n, 1, 0x1),
    ];
    for &(counts, half, field) in steps.iter() {
        if upper < wanted {
            // Too few bits in the upper half: continue in the lower half.
            top -= 2 * half;
            wanted -= upper;
        }
        upper = (counts >> (top - half)) & field;
    }

    if upper < wanted {
        top -= 1;
    }

    Some(top - 1)
}
238
/// Number of layers in the bit set hierarchy walked by `BitIter`
/// (three summary layers above the layer of individual indices).
const LAYERS: usize = 4;

/// log2 of the bit width of `usize`: how far an index is shifted when moving
/// from one layer to the layer below it.
#[cfg(target_pointer_width = "64")]
pub const BITS: usize = 6;

/// log2 of the bit width of `usize` on 32-bit targets (see the 64-bit case).
#[cfg(target_pointer_width = "32")]
pub const BITS: usize = 5;
246
#[cfg(test)]
mod test_bit_average {
    use hibitset::BitSet;

    use super::*;

    /// Builds a `BitSet` holding exactly `indices`.
    fn set_of(indices: &[u32]) -> BitSet {
        let mut set = BitSet::new();
        for &i in indices {
            set.add(i);
        }
        set
    }

    #[test]
    fn iterator_clone() {
        let set = set_of(&[1, 3]);

        // Exercise this module's `BitIter`, not hibitset's own iterator.
        let iter = BitIter::new(&set).skip(1);
        for (a, b) in iter.clone().zip(iter) {
            assert_eq!(a, b);
        }
    }

    #[test]
    fn iterates_in_ascending_order() {
        let indices = [1, 3, 70, 5000, 300_000];
        let set = set_of(&indices);

        assert_eq!(BitIter::new(&set).collect::<Vec<_>>(), indices.to_vec());
    }

    #[test]
    fn split_covers_every_index_once() {
        let indices = [1, 3, 70, 5000, 300_000];
        let set = set_of(&indices);

        // Split recursively until no iterator can be split any further.
        let mut pending = vec![BitIter::new(&set)];
        let mut done = Vec::new();
        while let Some(iter) = pending.pop() {
            match iter.split() {
                (low, Some(high)) => {
                    pending.push(low);
                    pending.push(high);
                }
                (whole, None) => done.push(whole),
            }
        }

        let mut seen: Vec<u32> = done.into_iter().flatten().collect();
        seen.sort();
        assert_eq!(seen, indices.to_vec());
    }

    #[test]
    fn split_of_empty_set_yields_nothing() {
        let set = BitSet::new();
        let (iter, other) = BitIter::new(&set).split();

        assert!(other.is_none());
        assert_eq!(iter.count(), 0);
    }

    #[test]
    fn parity_0_bit_average_u32() {
        struct EvenParity(u32);

        impl Iterator for EvenParity {
            type Item = u32;
            fn next(&mut self) -> Option<Self::Item> {
                if self.0 == u32::max_value() {
                    return None;
                }
                self.0 += 1;
                while self.0.count_ones() & 1 != 0 {
                    if self.0 == u32::max_value() {
                        return None;
                    }
                    self.0 += 1;
                }
                Some(self.0)
            }
        }

        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u32::max_value() / steps);
            for i in EvenParity(pos).take(steps as usize) {
                let mask = (1 << bit_average_u32(i).unwrap_or(31)) - 1;
                assert_eq!((i & mask).count_ones(), (i & !mask).count_ones(), "{:x}", i);
            }
        }
    }

    #[test]
    fn parity_1_bit_average_u32() {
        struct OddParity(u32);

        impl Iterator for OddParity {
            type Item = u32;
            fn next(&mut self) -> Option<Self::Item> {
                if self.0 == u32::max_value() {
                    return None;
                }
                self.0 += 1;
                while self.0.count_ones() & 1 == 0 {
                    if self.0 == u32::max_value() {
                        return None;
                    }
                    self.0 += 1;
                }
                Some(self.0)
            }
        }

        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u32::max_value() / steps);
            for i in OddParity(pos).take(steps as usize) {
                let mask = (1 << bit_average_u32(i).unwrap_or(31)) - 1;
                let a = (i & mask).count_ones();
                let b = (i & !mask).count_ones();
                if a < b {
                    assert_eq!(a + 1, b, "{:x}", i);
                } else if b < a {
                    assert_eq!(a, b + 1, "{:x}", i);
                } else {
                    panic!("Odd parity shouldn't split in exactly half");
                }
            }
        }
    }

    #[test]
    fn empty_bit_average_u32() {
        assert_eq!(None, bit_average_u32(0));
    }

    #[test]
    fn singleton_bit_average_u32() {
        for i in 0..32 {
            assert_eq!(None, bit_average_u32(1 << i), "{:x}", i);
        }
    }

    #[test]
    fn parity_0_bit_average_u64() {
        struct EvenParity(u64);

        impl Iterator for EvenParity {
            type Item = u64;
            fn next(&mut self) -> Option<Self::Item> {
                if self.0 == u64::max_value() {
                    return None;
                }
                self.0 += 1;
                while self.0.count_ones() & 1 != 0 {
                    if self.0 == u64::max_value() {
                        return None;
                    }
                    self.0 += 1;
                }
                Some(self.0)
            }
        }

        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u64::max_value() / steps);
            for i in EvenParity(pos).take(steps as usize) {
                let mask = (1 << bit_average_u64(i).unwrap_or(63)) - 1;
                assert_eq!((i & mask).count_ones(), (i & !mask).count_ones(), "{:x}", i);
            }
        }
    }

    #[test]
    fn parity_1_bit_average_u64() {
        struct OddParity(u64);

        impl Iterator for OddParity {
            type Item = u64;
            fn next(&mut self) -> Option<Self::Item> {
                if self.0 == u64::max_value() {
                    return None;
                }
                self.0 += 1;
                while self.0.count_ones() & 1 == 0 {
                    if self.0 == u64::max_value() {
                        return None;
                    }
                    self.0 += 1;
                }
                Some(self.0)
            }
        }

        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u64::max_value() / steps);
            for i in OddParity(pos).take(steps as usize) {
                let mask = (1 << bit_average_u64(i).unwrap_or(63)) - 1;
                let a = (i & mask).count_ones();
                let b = (i & !mask).count_ones();
                if a < b {
                    assert_eq!(a + 1, b, "{:x}", i);
                } else if b < a {
                    assert_eq!(a, b + 1, "{:x}", i);
                } else {
                    panic!("Odd parity shouldn't split in exactly half");
                }
            }
        }
    }

    #[test]
    fn empty_bit_average_u64() {
        assert_eq!(None, bit_average_u64(0));
    }

    #[test]
    fn singleton_bit_average_u64() {
        for i in 0..64 {
            assert_eq!(None, bit_average_u64(1 << i), "{:x}", i);
        }
    }

    #[test]
    fn bit_average_agree_u32_u64() {
        let steps = 1000;
        for i in 0..steps {
            let pos = i * (u32::max_value() / steps);
            // Check a window of `steps` values starting at `pos`. The range
            // used to be `pos..steps`, which is empty for every `pos >= steps`,
            // so only 0..1000 was ever compared. `pos + steps` cannot overflow:
            // the largest `pos` is 999 * (u32::MAX / 1000).
            for i in pos..pos + steps {
                assert_eq!(
                    bit_average_u32(i),
                    bit_average_u64(i as u64).map(|n| n as u32),
                    "{:x}",
                    i
                );
            }
        }
    }

    #[test]
    fn specific_values() {
        assert_eq!(Some(4), bit_average_u32(0b10110));
        assert_eq!(Some(5), bit_average_u32(0b100010));
        assert_eq!(None, bit_average_u32(0));
        assert_eq!(None, bit_average_u32(1));

        assert_eq!(Some(4), bit_average_u64(0b10110));
        assert_eq!(Some(5), bit_average_u64(0b100010));
        assert_eq!(None, bit_average_u64(0));
        assert_eq!(None, bit_average_u64(1));
    }
}