layered_bitset/
layered.rs

1use crate::{BitSet, BitSetMut};
2
/// A two-level bit set.
///
/// The bits themselves live in `N` bottom bit sets of type `B`, each covering
/// `B::UPPER_BOUND` consecutive indices. The top bit set of type `T` holds one
/// summary bit per bottom set, so whole empty chunks can be skipped.
#[derive(Debug)]
pub struct Layered<T, B, const N: usize> {
    /// Summary layer: bit `i` is kept set while `bottom[i]` contains any set bit.
    top: T,
    /// Storage layer: chunk `i` holds indices `i * B::UPPER_BOUND ..`.
    bottom: [B; N],
}
8
9impl<T, B, const N: usize> Default for Layered<T, B, N>
10where
11    T: Default + BitSet,
12    B: Default + BitSet,
13{
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl<T, B, const N: usize> Layered<T, B, N>
20where
21    T: BitSet,
22    B: BitSet,
23{
24    fn new() -> Self
25    where
26        T: Default,
27        B: Default,
28    {
29        assert_eq!(T::UPPER_BOUND, N as u32);
30        assert_eq!(T::UPPER_BOUND as usize, N);
31
32        use core::mem::MaybeUninit;
33        let bottom = unsafe {
34            let mut array: [MaybeUninit<B>; N] = MaybeUninit::uninit().assume_init();
35            for item in &mut array {
36                // Leaks all previously written elements on panic. It is safe though.
37                core::ptr::write(item, MaybeUninit::new(B::default()));
38            }
39            (&array as *const _ as *const [B; N]).read()
40        };
41
42        Layered {
43            top: T::default(),
44            bottom,
45        }
46    }
47}
48
49impl<T, B, const N: usize> BitSet for Layered<T, B, N>
50where
51    T: BitSet,
52    B: BitSet,
53{
54    const UPPER_BOUND: u32 = B::UPPER_BOUND * N as u32;
55
56    fn get(&self, index: u32) -> bool {
57        assert!(index < Self::UPPER_BOUND);
58
59        let t = index / B::UPPER_BOUND;
60        let b = index % B::UPPER_BOUND;
61
62        self.bottom[t as usize].get(b)
63    }
64
65    fn find_set(&self, lower_bound: u32) -> Option<u32> {
66        assert!(lower_bound < Self::UPPER_BOUND);
67
68        let t = lower_bound / B::UPPER_BOUND;
69        let b = lower_bound % B::UPPER_BOUND;
70
71        if b == 0 {
72            let t = self.top.find_set(t)?;
73            let b = self.bottom[t as usize].find_set(0)?;
74            Some(t * B::UPPER_BOUND + b)
75        } else {
76            if self.top.get(t) {
77                if let Some(b) = self.bottom[t as usize].find_set(b) {
78                    return Some(t * B::UPPER_BOUND + b);
79                }
80            }
81
82            let t = self.top.find_set(t + 1)?;
83            let b = self.bottom[t as usize].find_set(0)?;
84            Some(t * B::UPPER_BOUND + b)
85        }
86    }
87}
88
89impl<T, B, const N: usize> BitSetMut for Layered<T, B, N>
90where
91    T: BitSetMut,
92    B: BitSetMut,
93{
94    fn set(&mut self, index: u32, bit: bool) {
95        assert!(index < Self::UPPER_BOUND);
96        let t = index / B::UPPER_BOUND;
97        let u = index % B::UPPER_BOUND;
98
99        if bit {
100            if !self.top.get(t) {
101                self.top.set(t, true);
102            }
103            self.bottom[t as usize].set(u, true)
104        } else {
105            if self.top.get(t) {
106                self.bottom[t as usize].set(u, false);
107                if self.bottom[t as usize].is_empty() {
108                    self.top.set(t, false);
109                }
110            }
111        }
112    }
113}