1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use std::ops::Range;

/// A basic bitvector trait that we implement for mmap
///
/// Implementors provide `get`, `set` and `size`; every other method has a default
/// implementation built on those three, working one bit at a time.
pub trait BitVector {
    /// Get the value at bit `i`
    fn get(&self, i: usize) -> bool;
    /// Set the value at bit `i`
    fn set(&mut self, i: usize, x: bool);
    /// Returns the size of the bitvector, in bits
    fn size(&self) -> usize;

    /// Returns the number of bits set in the given range
    fn rank(&self, r: Range<usize>) -> usize {
        r.filter(|&i| self.get(i)).count()
    }

    /// Returns the position of the `n`-th set bit at or after `start`.
    ///
    /// `n` is 1-based: `select(1, start)` is the first set bit in `start..self.size()`.
    /// Returns `None` when `n == 0` or when that range holds fewer than `n` set bits.
    fn select(&self, n: usize, start: usize) -> Option<usize> {
        // There is no 0-th set bit; bail out instead of underflowing the counter.
        let skip = n.checked_sub(1)?;
        (start..self.size()).filter(|&i| self.get(i)).nth(skip)
    }

    /// Return all the bits in the given range as a `u128`. The input range `r` must span
    /// `<= 128` bits, as the result is bitpacked into a `u128`.
    ///
    /// The result is MSB-aligned: bit `r.start` lands in bit 127 of the result, bit
    /// `r.start + 1` in bit 126, and so on; unused low bits are zero. For example,
    /// `get_range(0..8)` fills the top 8 bits of the returned `u128` with
    /// `self.get(0)`, `self.get(1)`, ..., `self.get(7)`.
    ///
    /// # Panics
    ///
    /// Panics if the range spans more than 128 bits or ends past `self.size()`.
    fn get_range(&self, r: Range<usize>) -> u128 {
        // `len()` is 0 for an empty/reversed range, whereas `r.end - r.start` would underflow.
        if r.len() > 128 {
            panic!("Range too large (>128)")
        } else if r.end > self.size() {
            panic!("Range ends outside of BitVec")
        }

        r.enumerate()
            .filter(|&(_, i)| self.get(i))
            .fold(0u128, |acc, (offset, _)| acc | 1u128 << (127 - offset))
    }

    /// Sets all the bits in the given range from the given `u128`.
    ///
    /// The input is LSB-aligned: bit 0 of `x` is written to `r.end - 1`, bit 1 to
    /// `r.end - 2`, and so on. Positions more than 128 bits before `r.end` are cleared,
    /// since `x` has been fully shifted out by then.
    ///
    /// NOTE(review): `get_range` is MSB-aligned while this is LSB-aligned, so the two are
    /// only inverses for ranges spanning exactly 128 bits — confirm this is intended.
    fn set_range(&mut self, r: Range<usize>, x: u128) {
        let mut cur = x;
        for i in r.rev() {
            self.set(i, cur & 1 == 1);
            cur >>= 1;
        }
    }

    /// Sets all the bits in the given range to false
    fn clear_range(&mut self, r: Range<usize>) {
        for i in r.rev() {
            self.set(i, false);
        }
    }
}

/// Implements `BitVector` for a primitive unsigned integer type.
///
/// `$type_size` must be the bit width of `$type`. Bits are indexed most-significant
/// first: index `0` is the highest bit of the integer and index `$type_size - 1` is the
/// least-significant bit. This matches `get_range`, which packs the first index of a
/// range into the top bit of its result.
macro_rules! impl_bitvector {
    ( $type:ty, $type_size:expr ) => {
        impl BitVector for $type {
            fn get(&self, i: usize) -> bool {
                if i >= $type_size {
                    panic!("Invalid bit vector index");
                }
                // Index 0 is the MSB, so bit `i` sits at shift `size - 1 - i`.
                (*self >> ($type_size - 1 - i)) & 1 == 1
            }

            fn set(&mut self, i: usize, x: bool) {
                // Check first so an out-of-range index reports the same error as `get`
                // instead of underflowing the shift computation.
                if i >= $type_size {
                    panic!("Invalid bit vector index");
                }
                let mask: $type = 1 << ($type_size - 1 - i);
                if x {
                    *self |= mask;
                } else {
                    *self &= !mask;
                }
            }

            fn size(&self) -> usize {
                $type_size
            }
        }
    };
}

impl_bitvector!(u8, 8);
impl_bitvector!(u16, 16);
impl_bitvector!(u32, 32);
impl_bitvector!(u64, 64);
impl_bitvector!(u128, 128);

/// Read-only `BitVector` view over a byte slice.
///
/// Bits are numbered most-significant first within each byte: bit `i` is bit
/// `7 - i % 8` (counting from the LSB) of byte `i / 8`. The slice is only borrowed
/// immutably, so `set` always panics.
impl BitVector for &[u8] {
    fn get(&self, i: usize) -> bool {
        if i / 8 >= self.len() {
            panic!("Invalid bit vector index");
        }
        // Shift the addressed bit down to position 0 (MSB-first within the byte).
        (self[i / 8] >> (7 - i % 8)) & 1 == 1
    }

    fn set(&mut self, _: usize, _: bool) {
        panic!("Can not set bits on a non-mut slice");
    }

    /// Number of addressable bits: eight per byte.
    fn size(&self) -> usize {
        self.len() * 8
    }
}

/// Owned, mutable `BitVector` backed by a byte vector.
///
/// Bits are numbered most-significant first within each byte: bit `i` is bit
/// `7 - i % 8` (counting from the LSB) of byte `i / 8`.
impl BitVector for Vec<u8> {
    fn get(&self, i: usize) -> bool {
        if i / 8 >= self.len() {
            panic!("Invalid bit vector index");
        }
        // Shift the addressed bit down to position 0 (MSB-first within the byte).
        (self[i / 8] >> (7 - i % 8)) & 1 == 1
    }

    fn set(&mut self, i: usize, x: bool) {
        if i / 8 >= self.len() {
            panic!("Invalid bit vector index");
        }
        // Use the bit offset within the byte, not the global index, for the shift.
        let mask = 1u8 << (7 - i % 8);
        if x {
            self[i / 8] |= mask;
        } else {
            self[i / 8] &= !mask;
        }
    }

    /// Number of addressable bits: eight per byte.
    fn size(&self) -> usize {
        self.len() * 8
    }
}