aptos_bitvec/
lib.rs

1// Copyright (c) Aptos
2// SPDX-License-Identifier: Apache-2.0
3
4//! This library defines a BitVec struct that represents a bit vector.
5
6#[cfg(any(test, feature = "fuzzing"))]
7use proptest::{
8    arbitrary::{any, Arbitrary, StrategyFor},
9    collection::{vec, VecStrategy},
10    strategy::{Map, Strategy},
11};
12use serde::{de::Error, Deserialize, Deserializer, Serialize};
13use std::{
14    iter::FromIterator,
15    ops::{BitAnd, BitOr},
16};
17
// Every u8 of the inner vector is used as a bucket of 8 bits. Positions are addressed with a u8
// (0..=255), so the total max number of buckets is 256 / 8 = 32.
const BUCKET_SIZE: usize = 8;
const MAX_BUCKETS: usize = 32;
21
/// BitVec represents a bit vector that supports 4 operations:
///
/// 1. Marking a position as set.
/// 2. Checking if a position is set.
/// 3. Count set bits.
/// 4. Get the index of the last set bit.
///
/// Internally, it stores a vector of u8's (as Vec<u8>).
///
/// * The first 8 positions of the bit vector are encoded in the first element of the vector, the
///   next 8 are encoded in the second element, and so on.
/// * Positions are zero-based and bits are read from left to right (most significant bit first).
///   For instance, in the following bitvec
///   [0b0001_0000, 0b0000_0000, 0b0000_0000, 0b0000_0001], positions 3 and 31 are set.
/// * Each bit of a u8 is set to 1 if the position is set and to 0 if it's not.
/// * We only allow setting positions up to u8::MAX. As a result, the size of the inner vector is
///   limited to 32 (= 256 / 8).
/// * Once a bit has been set, it cannot be unset. As a result, the inner vector cannot shrink.
/// * The positions can be set in any order.
/// * A position can be set more than once -- it remains set after the first time.
/// * Equality is derived from the inner vector, so two bit vectors with the same set positions
///   but a different number of trailing zero buckets compare as not equal.
///
/// # Examples:
/// ```
/// use aptos_bitvec::BitVec;
/// use std::ops::BitAnd;
///
/// let mut bv = BitVec::default();
/// bv.set(2);
/// bv.set(5);
/// assert!(bv.is_set(2));
/// assert!(bv.is_set(5));
/// assert_eq!(false, bv.is_set(0));
/// assert_eq!(bv.count_ones(), 2);
/// assert_eq!(bv.last_set_bit(), Some(5));
///
/// // A bitwise AND of BitVec can be performed by using the `&` operator.
/// let mut bv1 = BitVec::default();
/// bv1.set(2);
/// bv1.set(3);
/// let mut bv2 = BitVec::default();
/// bv2.set(2);
/// let intersection = bv1.bitand(&bv2);
/// assert!(intersection.is_set(2));
/// assert_eq!(false, intersection.is_set(3));
/// ```
#[derive(Clone, Default, Debug, Eq, PartialEq, Serialize)]
pub struct BitVec {
    // Serialized as a byte string (via serde_bytes) rather than as a sequence of u8 elements.
    #[serde(with = "serde_bytes")]
    inner: Vec<u8>,
}
71
72impl BitVec {
73    fn with_capacity(num_buckets: usize) -> Self {
74        Self {
75            inner: Vec::with_capacity(num_buckets),
76        }
77    }
78
79    /// Sets the bit at position @pos.
80    pub fn set(&mut self, pos: u8) {
81        // This is optimised to: let bucket = pos >> 3;
82        let bucket: usize = pos as usize / BUCKET_SIZE;
83        if self.inner.len() <= bucket {
84            self.inner.resize(bucket + 1, 0);
85        }
86        // This is optimized to: let bucket_pos = pos | 0x07;
87        let bucket_pos = pos as usize - (bucket * BUCKET_SIZE);
88        self.inner[bucket] |= 0b1000_0000 >> bucket_pos as u8;
89    }
90
91    /// Checks if the bit at position @pos is set.
92    #[inline]
93    pub fn is_set(&self, pos: u8) -> bool {
94        // This is optimised to: let bucket = pos >> 3;
95        let bucket: usize = pos as usize / BUCKET_SIZE;
96        if self.inner.len() <= bucket {
97            return false;
98        }
99        // This is optimized to: let bucket_pos = pos | 0x07;
100        let bucket_pos = pos as usize - (bucket * BUCKET_SIZE);
101        (self.inner[bucket] & (0b1000_0000 >> bucket_pos as u8)) != 0
102    }
103
104    /// Return true if the BitVec is all zeros.
105    pub fn all_zeros(&self) -> bool {
106        self.inner.iter().all(|byte| *byte == 0)
107    }
108
109    /// Returns the number of set bits.
110    pub fn count_ones(&self) -> u32 {
111        self.inner.iter().map(|a| a.count_ones()).sum()
112    }
113
114    /// Returns the index of the last set bit.
115    pub fn last_set_bit(&self) -> Option<u8> {
116        self.inner
117            .iter()
118            .rev()
119            .enumerate()
120            .find(|(_, byte)| byte != &&0u8)
121            .map(|(i, byte)| {
122                (8 * (self.inner.len() - i) - byte.trailing_zeros() as usize - 1) as u8
123            })
124    }
125
126    /// Return an `Iterator` over all '1' bit indexes.
127    pub fn iter_ones(&self) -> impl Iterator<Item = u8> + '_ {
128        (0..=u8::MAX).filter(move |idx| self.is_set(*idx))
129    }
130}
131
132impl BitAnd for &BitVec {
133    type Output = BitVec;
134
135    /// Returns a new BitVec that is a bitwise AND of two BitVecs.
136    fn bitand(self, other: Self) -> Self::Output {
137        let len = std::cmp::min(self.inner.len(), other.inner.len());
138        let mut ret = BitVec::with_capacity(len);
139        for i in 0..len {
140            ret.inner.push(self.inner[i] & other.inner[i]);
141        }
142        ret
143    }
144}
145
146impl BitOr for &BitVec {
147    type Output = BitVec;
148
149    /// Returns a new BitVec that is a bitwise OR of two BitVecs.
150    fn bitor(self, other: Self) -> Self::Output {
151        let len = std::cmp::max(self.inner.len(), other.inner.len());
152        let mut ret = BitVec::with_capacity(len);
153        for i in 0..len {
154            let a = self.inner.get(i).copied().unwrap_or(0);
155            let b = other.inner.get(i).copied().unwrap_or(0);
156            ret.inner.push(a | b);
157        }
158        ret
159    }
160}
161
162impl FromIterator<u8> for BitVec {
163    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
164        let mut bitvec = Self::default();
165        for bit in iter {
166            bitvec.set(bit);
167        }
168        bitvec
169    }
170}
171
172// We impl custom deserialization to ensure that the length of inner vector does not exceed
173// 32 (= 256 / 8).
174impl<'de> Deserialize<'de> for BitVec {
175    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
176    where
177        D: Deserializer<'de>,
178    {
179        let v = serde_bytes::ByteBuf::deserialize(deserializer)?.into_vec();
180        if v.len() > MAX_BUCKETS {
181            return Err(D::Error::custom(format!("BitVec too long: {}", v.len())));
182        }
183        Ok(BitVec { inner: v })
184    }
185}
186
// proptest support, compiled only for tests or when the `fuzzing` feature is enabled.
// Generated values wrap an inner vector of 0 to MAX_BUCKETS (inclusive) arbitrary bytes.
#[cfg(any(test, feature = "fuzzing"))]
impl Arbitrary for BitVec {
    // Generation takes no parameters.
    type Parameters = ();
    // The mapping step is typed as a plain `fn` pointer so the strategy type can be named.
    type Strategy = Map<VecStrategy<StrategyFor<u8>>, fn(Vec<u8>) -> BitVec>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        vec(any::<u8>(), 0..=MAX_BUCKETS).prop_map(|inner| BitVec { inner })
    }
}
196
#[cfg(test)]
mod test {
    use super::*;
    use proptest::proptest;

    /// `count_ones` adds up the set bits of every bucket.
    #[test]
    fn test_count_ones() {
        let p0 = BitVec::default();
        assert_eq!(p0.count_ones(), 0);
        // 7 = 0b0000_0111 and 15 = 0b0000_1111
        let p1 = BitVec {
            inner: vec![7u8, 15u8],
        };
        assert_eq!(p1.count_ones(), 7);

        let p2 = BitVec {
            inner: vec![7u8; MAX_BUCKETS],
        };
        assert_eq!(p2.count_ones(), 3 * MAX_BUCKETS as u32);

        // 255 = 0b1111_1111
        let p3 = BitVec {
            inner: vec![255u8; MAX_BUCKETS],
        };
        assert_eq!(p3.count_ones(), 8 * MAX_BUCKETS as u32);

        // 0 = 0b0000_0000
        let p4 = BitVec {
            inner: vec![0u8; MAX_BUCKETS],
        };
        assert_eq!(p4.count_ones(), 0);
    }

    /// `last_set_bit` reports the highest set position, reading each bucket left to right.
    #[test]
    fn test_last_set_bit() {
        let p0 = BitVec::default();
        assert_eq!(p0.last_set_bit(), None);
        // 224 = 0b1110_0000
        let p1 = BitVec { inner: vec![224u8] };
        assert_eq!(p1.inner.len(), 1);
        assert_eq!(p1.last_set_bit(), Some(2));

        // 128 = 0b1000_0000
        let p2 = BitVec {
            inner: vec![7u8, 128u8],
        };
        assert_eq!(p2.inner.len(), 2);
        assert_eq!(p2.last_set_bit(), Some(8));

        let p3 = BitVec {
            inner: vec![255u8; MAX_BUCKETS],
        };
        assert_eq!(p3.inner.len(), MAX_BUCKETS);
        assert_eq!(p3.last_set_bit(), Some(255));

        let p4 = BitVec {
            inner: vec![0u8; MAX_BUCKETS],
        };
        assert_eq!(p4.last_set_bit(), None);

        // An extra test to ensure left to right encoding.
        let mut p5 = BitVec {
            inner: vec![0b0000_0001, 0b0100_0000],
        };
        assert_eq!(p5.last_set_bit(), Some(9));
        assert!(p5.is_set(7));
        assert!(p5.is_set(9));
        assert!(!p5.is_set(0));

        p5.set(10);
        assert!(p5.is_set(10));
        assert_eq!(p5.last_set_bit(), Some(10));
        assert_eq!(p5.inner, vec![0b0000_0001, 0b0110_0000]);

        let p6 = BitVec {
            inner: vec![0b1000_0000],
        };
        assert_eq!(p6.inner.len(), 1);
        assert_eq!(p6.last_set_bit(), Some(0));
    }

    /// A default BitVec has no position set.
    #[test]
    fn test_empty() {
        let p = BitVec::default();
        for i in 0..=u8::MAX {
            assert!(!p.is_set(i));
        }
    }

    /// Setting only the first and the last position leaves every other position unset.
    #[test]
    fn test_extremes() {
        let mut p = BitVec::default();
        p.set(u8::MAX);
        p.set(0);
        assert!(p.is_set(u8::MAX));
        assert!(p.is_set(0));
        for i in 1..u8::MAX {
            assert!(!p.is_set(i));
        }
        assert_eq!(vec![0, u8::MAX], p.iter_ones().collect::<Vec<_>>());
    }

    /// Deserialization rejects more than MAX_BUCKETS bytes and accepts exactly MAX_BUCKETS.
    #[test]
    fn test_deserialization() {
        // When the length is smaller than 128, it is encoded in the first byte.
        // (see comments in BCS crate)
        let mut bytes = [0u8; 47];
        bytes[0] = 46;
        assert!(bcs::from_bytes::<Vec<u8>>(&bytes).is_ok());
        // However, 46 > MAX_BUCKETS:
        assert!(bcs::from_bytes::<BitVec>(&bytes).is_err());
        let mut bytes = [0u8; 33];
        bytes[0] = 32;
        let bv = BitVec {
            inner: Vec::from([0u8; 32].as_ref()),
        };
        assert_eq!(Ok(bv), bcs::from_bytes::<BitVec>(&bytes));
    }

    proptest! {
        // Test for bitwise AND operation on 2 bitvecs.
        #[test]
        fn test_and(bv1 in any::<BitVec>(), bv2 in any::<BitVec>()) {
            let intersection = bv1.bitand(&bv2);

            assert!(intersection.count_ones() <= bv1.count_ones());
            assert!(intersection.count_ones() <= bv2.count_ones());

            for i in 0..=u8::MAX {
                if bv1.is_set(i) && bv2.is_set(i) {
                    assert!(intersection.is_set(i));
                } else {
                    assert!(!intersection.is_set(i));
                }
            }
        }

        // Test for bitwise OR operation on 2 bitvecs.
        #[test]
        fn test_or(bv1 in any::<BitVec>(), bv2 in any::<BitVec>()) {
            let union = bv1.bitor(&bv2);

            assert!(union.count_ones() >= bv1.count_ones());
            assert!(union.count_ones() >= bv2.count_ones());

            for i in 0..=u8::MAX {
                if bv1.is_set(i) || bv2.is_set(i) {
                    assert!(union.is_set(i));
                } else {
                    assert!(!union.is_set(i));
                }
            }
        }

        // `iter_ones` yields exactly as many positions as `count_ones` reports.
        #[test]
        fn test_iter_ones(bv1 in any::<BitVec>()) {
            assert_eq!(bv1.iter_ones().count(), bv1.count_ones() as usize);
        }
    }
}