mmap_bitvec/
bitvec.rs

1use std::ops::Range;
2
/// A basic bitvector trait that we implement for mmap.
///
/// Implementors only need to provide [`get`](BitVector::get), [`set`](BitVector::set) and
/// [`size`](BitVector::size); every other method has a default built on top of those three,
/// which an implementor may override with a faster version.
pub trait BitVector {
    /// Get the value at bit `i`
    fn get(&self, i: usize) -> bool;
    /// Set the value at bit `i`
    fn set(&mut self, i: usize, x: bool);
    /// Returns the size of the bitvector, in bits
    fn size(&self) -> usize;

    /// Returns the number of bits set in the given range.
    fn rank(&self, r: Range<usize>) -> usize {
        r.filter(|&i| self.get(i)).count()
    }

    /// Returns the position of the `n`-th set bit (1-based) at or after bit `start`.
    ///
    /// `select(1, start)` is the first set bit with index `>= start`. Returns `None` when fewer
    /// than `n` bits are set in `start..self.size()`, and for `n == 0`, since there is no
    /// "zeroth" set bit.
    fn select(&self, n: usize, start: usize) -> Option<usize> {
        // guards `n == 0`, which would otherwise underflow when converting to a 0-based skip
        let skip = n.checked_sub(1)?;
        (start..self.size()).filter(|&i| self.get(i)).nth(skip)
    }

    /// Returns the bits in range `r` packed into the low bits of a `u128`.
    ///
    /// The last bit of the range lands in the least-significant bit of the result, which makes
    /// this the inverse of [`set_range`](BitVector::set_range): after `set_range(r, x)`,
    /// `get_range(r)` returns the low `r.len()` bits of `x`. For example, for the range `0..8`,
    /// `self.get(0)` becomes bit 7 of the result and `self.get(7)` becomes bit 0.
    ///
    /// # Panics
    ///
    /// Panics if the range spans more than 128 bits or ends past `self.size()`.
    fn get_range(&self, r: Range<usize>) -> u128 {
        // `len()` is 0 for a reversed range instead of underflowing like `end - start`
        if r.len() > 128 {
            panic!("Range too large (>128)")
        } else if r.end > self.size() {
            panic!("Range ends outside of BitVec")
        }

        // shift earlier bits up and append each new bit at the bottom
        r.fold(0u128, |acc, i| (acc << 1) | u128::from(self.get(i)))
    }

    /// Sets the bits in range `r` from the low bits of `x`.
    ///
    /// The last bit of the range receives the least-significant bit of `x`, the bit before it
    /// the next one, and so on. For ranges longer than 128 bits, the leading bits beyond the
    /// width of `x` are cleared.
    fn set_range(&mut self, r: Range<usize>, x: u128) {
        let mut cur = x;
        for i in r.rev() {
            self.set(i, cur & 1 == 1);
            cur >>= 1;
        }
    }

    /// Sets all the bits in the given range to false
    fn clear_range(&mut self, r: Range<usize>) {
        for i in r.rev() {
            self.set(i, false);
        }
    }
}

/// Implements [`BitVector`] for a fixed-width unsigned integer type.
///
/// `$type_size` is the width of `$type` in bits. Bit `0` is the most-significant bit of the
/// integer and bit `$type_size - 1` the least-significant one.
macro_rules! impl_bitvector {
    ( $type:ty, $type_size:expr ) => {
        impl BitVector for $type {
            /// # Panics
            ///
            /// Panics if `i >= $type_size`.
            fn get(&self, i: usize) -> bool {
                if i >= $type_size {
                    panic!("Invalid bit vector index");
                }
                // bit 0 is the MSB, so it needs a shift of width - 1 (a shift by the full width
                // would overflow)
                ((*self >> ($type_size - 1 - i)) & 1) == 1
            }

            /// # Panics
            ///
            /// Panics if `i >= $type_size`.
            fn set(&mut self, i: usize, x: bool) {
                if i >= $type_size {
                    panic!("Invalid bit vector index");
                }
                let mask: $type = 1 << ($type_size - 1 - i);
                if x {
                    *self |= mask;
                } else {
                    *self &= !mask;
                }
            }

            fn size(&self) -> usize {
                $type_size
            }
        }
    };
}

97impl_bitvector!(u8, 8);
98impl_bitvector!(u16, 16);
99impl_bitvector!(u32, 32);
100impl_bitvector!(u64, 64);
101impl_bitvector!(u128, 128);
102
103impl BitVector for &[u8] {
104    fn get(&self, i: usize) -> bool {
105        if i / 8 >= self.size() {
106            panic!("Invalid bit vector index");
107        }
108        self[i / 8] >> (8 - i % 8) & 1 == 1
109    }
110
111    fn set(&mut self, _: usize, _: bool) {
112        panic!("Can not set bits on a non-mut slice");
113    }
114
115    fn size(&self) -> usize {
116        self.len() / 8
117    }
118}
119
120impl BitVector for Vec<u8> {
121    fn get(&self, i: usize) -> bool {
122        if i / 8 >= self.len() {
123            panic!("Invalid bit vector index");
124        }
125        self[i / 8] >> (8 - i % 8) & 1 == 1
126    }
127
128    fn set(&mut self, i: usize, x: bool) {
129        if i / 8 >= self.len() {
130            panic!("Invalid bit vector index");
131        }
132        if x {
133            self[i / 8] |= 1 << (8 - i);
134        } else {
135            self[i / 8] &= !(1 << (8 - i));
136        }
137    }
138
139    fn size(&self) -> usize {
140        self.len() / 8
141    }
142}