aptos_bitvec_link/
lib.rs

1// Copyright (c) Aptos
2// SPDX-License-Identifier: Apache-2.0
3
4//! This library defines a BitVec struct that represents a bit vector.
5
6#[cfg(any(test, feature = "fuzzing"))]
7use proptest::{
8    arbitrary::{any, Arbitrary, StrategyFor},
9    collection::{vec, VecStrategy},
10    strategy::{Map, Strategy},
11};
12use serde::{de::Error, Deserialize, Deserializer, Serialize};
13use std::{
14    iter::FromIterator,
15    ops::{BitAnd, BitOr},
16};
17
// Every u8 is used as a bucket of 8 bits. Total max buckets = 65536 / 8 = 8192.
// Number of bit positions stored in one bucket (one `u8`).
const BUCKET_SIZE: usize = 8;
// Upper bound on the number of buckets: enough to cover every position addressable by a `u16`.
const MAX_BUCKETS: usize = 8192;
21
/// BitVec represents a bit vector that supports 4 operations:
///
/// 1. Marking a position as set.
/// 2. Checking if a position is set.
/// 3. Counting the set bits.
/// 4. Getting the index of the last set bit.
///
/// Internally, it stores a vector of u8's (as Vec<u8>).
///
/// * The first 8 positions of the bit vector are encoded in the first element of the vector, the
///   next 8 are encoded in the second element, and so on.
/// * Bits are read from left to right, and positions are counted from 0. For instance, in the
///   following bitvec [0b0001_0000, 0b0000_0000, 0b0000_0000, 0b0000_0001], positions 3 and 31
///   are set.
/// * Each bit of a u8 is set to 1 if the position is set and to 0 if it's not.
/// * We only allow setting positions up to u16::MAX. As a result, the size of the inner vector is
///   limited to 8192 (= 65536 / 8).
/// * Once a bit has been set, it cannot be unset. As a result, the inner vector cannot shrink.
/// * The positions can be set in any order.
/// * A position can be set more than once -- it remains set after the first time.
///
/// # Examples:
/// ```
/// use aptos_bitvec::BitVec;
/// use std::ops::BitAnd;
///
/// let mut bv = BitVec::default();
/// bv.set(2);
/// bv.set(5);
/// assert!(bv.is_set(2));
/// assert!(bv.is_set(5));
/// assert_eq!(false, bv.is_set(0));
/// assert_eq!(bv.count_ones(), 2);
/// assert_eq!(bv.last_set_bit(), Some(5));
///
/// // A bitwise AND of BitVec can be performed by using the `&` operator.
/// let mut bv1 = BitVec::default();
/// bv1.set(2);
/// bv1.set(3);
/// let mut bv2 = BitVec::default();
/// bv2.set(2);
/// let intersection = bv1.bitand(&bv2);
/// assert!(intersection.is_set(2));
/// assert_eq!(false, intersection.is_set(3));
/// ```
#[derive(Clone, Default, Debug, Eq, PartialEq, Serialize)]
pub struct BitVec {
    /// Packed bits: bucket `i` holds positions `8 * i ..= 8 * i + 7`, most significant bit first.
    #[serde(with = "serde_bytes")]
    inner: Vec<u8>,
}
71
72impl BitVec {
73    fn with_capacity(num_buckets: usize) -> Self {
74        Self {
75            inner: Vec::with_capacity(num_buckets),
76        }
77    }
78
79    /// Initialize with buckets that can fit in num_bits.
80    pub fn with_num_bits(num_bits: u16) -> Self {
81        Self {
82            inner: vec![0; Self::required_buckets(num_bits)],
83        }
84    }
85
86    /// Sets the bit at position @pos.
87    pub fn set(&mut self, pos: u16) {
88        // This is optimised to: let bucket = pos >> 3;
89        let bucket: usize = pos as usize / BUCKET_SIZE;
90        if self.inner.len() <= bucket {
91            self.inner.resize(bucket + 1, 0);
92        }
93        // This is optimized to: let bucket_pos = pos | 0x07;
94        let bucket_pos = pos as usize - (bucket * BUCKET_SIZE);
95        self.inner[bucket] |= 0b1000_0000 >> bucket_pos as u8;
96    }
97
98    /// Checks if the bit at position @pos is set.
99    #[inline]
100    pub fn is_set(&self, pos: u16) -> bool {
101        // This is optimised to: let bucket = pos >> 3;
102        let bucket: usize = pos as usize / BUCKET_SIZE;
103        if self.inner.len() <= bucket {
104            return false;
105        }
106        // This is optimized to: let bucket_pos = pos | 0x07;
107        let bucket_pos = pos as usize - (bucket * BUCKET_SIZE);
108        (self.inner[bucket] & (0b1000_0000 >> bucket_pos as u8)) != 0
109    }
110
111    /// Return true if the BitVec is all zeros.
112    pub fn all_zeros(&self) -> bool {
113        self.inner.iter().all(|byte| *byte == 0)
114    }
115
116    /// Returns the number of set bits.
117    pub fn count_ones(&self) -> u32 {
118        self.inner.iter().map(|a| a.count_ones()).sum()
119    }
120
121    /// Returns the index of the last set bit.
122    pub fn last_set_bit(&self) -> Option<u16> {
123        self.inner
124            .iter()
125            .rev()
126            .enumerate()
127            .find(|(_, byte)| byte != &&0u8)
128            .map(|(i, byte)| {
129                (8 * (self.inner.len() - i) - byte.trailing_zeros() as usize - 1) as u16
130            })
131    }
132
133    /// Return an `Iterator` over all '1' bit indexes.
134    pub fn iter_ones(&self) -> impl Iterator<Item = usize> + '_ {
135        (0..self.inner.len() * BUCKET_SIZE).filter(move |idx| self.is_set(*idx as u16))
136    }
137
138    /// Return the number of buckets.
139    pub fn num_buckets(&self) -> usize {
140        self.inner.len()
141    }
142
143    /// Number of buckets require for num_bits.
144    pub fn required_buckets(num_bits: u16) -> usize {
145        num_bits
146            .checked_sub(1)
147            .map_or(0, |pos| pos as usize / BUCKET_SIZE + 1)
148    }
149}
150
151impl BitAnd for &BitVec {
152    type Output = BitVec;
153
154    /// Returns a new BitVec that is a bitwise AND of two BitVecs.
155    fn bitand(self, other: Self) -> Self::Output {
156        let len = std::cmp::min(self.inner.len(), other.inner.len());
157        let mut ret = BitVec::with_capacity(len);
158        for i in 0..len {
159            ret.inner.push(self.inner[i] & other.inner[i]);
160        }
161        ret
162    }
163}
164
165impl BitOr for &BitVec {
166    type Output = BitVec;
167
168    /// Returns a new BitVec that is a bitwise OR of two BitVecs.
169    fn bitor(self, other: Self) -> Self::Output {
170        let len = std::cmp::max(self.inner.len(), other.inner.len());
171        let mut ret = BitVec::with_capacity(len);
172        for i in 0..len {
173            let a = self.inner.get(i).copied().unwrap_or(0);
174            let b = other.inner.get(i).copied().unwrap_or(0);
175            ret.inner.push(a | b);
176        }
177        ret
178    }
179}
180
181impl FromIterator<u8> for BitVec {
182    fn from_iter<T: IntoIterator<Item = u8>>(iter: T) -> Self {
183        let mut bitvec = Self::default();
184        for bit in iter {
185            bitvec.set(bit as u16);
186        }
187        bitvec
188    }
189}
190
191impl From<Vec<u8>> for BitVec {
192    fn from(raw_bytes: Vec<u8>) -> Self {
193        assert!(raw_bytes.len() <= MAX_BUCKETS);
194        Self { inner: raw_bytes }
195    }
196}
197
198impl From<BitVec> for Vec<u8> {
199    fn from(bitvec: BitVec) -> Self {
200        bitvec.inner
201    }
202}
203
204impl From<Vec<bool>> for BitVec {
205    fn from(bits: Vec<bool>) -> Self {
206        assert!(bits.len() <= MAX_BUCKETS * BUCKET_SIZE);
207        let mut bitvec = Self::with_num_bits(bits.len() as u16);
208        for (index, b) in bits.iter().enumerate() {
209            if *b {
210                bitvec.set(index as u16);
211            }
212        }
213        bitvec
214    }
215}
216
217impl<'de> Deserialize<'de> for BitVec {
218    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
219    where
220        D: Deserializer<'de>,
221    {
222        #[derive(Deserialize)]
223        #[serde(rename = "BitVec")]
224        struct RawData {
225            #[serde(with = "serde_bytes")]
226            inner: Vec<u8>,
227        }
228        let v = RawData::deserialize(deserializer)?.inner;
229        if v.len() > MAX_BUCKETS {
230            return Err(D::Error::custom(format!("BitVec too long: {}", v.len())));
231        }
232        Ok(BitVec { inner: v })
233    }
234}
235
#[cfg(any(test, feature = "fuzzing"))]
impl Arbitrary for BitVec {
    type Parameters = ();
    type Strategy = Map<VecStrategy<StrategyFor<u8>>, fn(Vec<u8>) -> BitVec>;

    /// Generates `BitVec`s backed by between 0 and `MAX_BUCKETS` (inclusive) arbitrary bytes.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        // Named as a plain function pointer so it matches the `Strategy` associated type.
        let wrap: fn(Vec<u8>) -> BitVec = |inner| BitVec { inner };
        let buckets = vec(any::<u8>(), 0..=MAX_BUCKETS);
        buckets.prop_map(wrap)
    }
}
245
#[cfg(test)]
mod test {
    use super::*;
    use proptest::proptest;

    /// `count_ones` sums the set bits over every bucket, from empty up to `MAX_BUCKETS` buckets.
    #[test]
    fn test_count_ones() {
        let p0 = BitVec::default();
        assert_eq!(p0.count_ones(), 0);
        // 7 = 0b0000_0111 and 15 = 0b0000_1111
        let p1 = BitVec {
            inner: vec![7u8, 15u8],
        };
        assert_eq!(p1.count_ones(), 7);

        let p2 = BitVec {
            inner: vec![7u8; MAX_BUCKETS],
        };
        assert_eq!(p2.count_ones(), 3 * MAX_BUCKETS as u32);

        // 255 = 0b1111_1111
        let p3 = BitVec {
            inner: vec![255u8; MAX_BUCKETS],
        };
        assert_eq!(p3.count_ones(), 8 * MAX_BUCKETS as u32);

        // 0 = 0b0000_0000
        let p4 = BitVec {
            inner: vec![0u8; MAX_BUCKETS],
        };
        assert_eq!(p4.count_ones(), 0);
    }

    /// `last_set_bit` reports the highest set position, reading each bucket left to right.
    #[test]
    fn test_last_set_bit() {
        let p0 = BitVec::default();
        assert_eq!(p0.last_set_bit(), None);
        // 224 = 0b1110_0000
        let p1 = BitVec { inner: vec![224u8] };
        assert_eq!(p1.inner.len(), 1);
        assert_eq!(p1.last_set_bit(), Some(2));

        // 128 = 0b1000_0000
        let p2 = BitVec {
            inner: vec![7u8, 128u8],
        };
        assert_eq!(p2.inner.len(), 2);
        assert_eq!(p2.last_set_bit(), Some(8));

        let p3 = BitVec {
            inner: vec![255u8; MAX_BUCKETS],
        };
        assert_eq!(p3.inner.len(), MAX_BUCKETS);
        assert_eq!(p3.last_set_bit(), Some(65535));

        let p4 = BitVec {
            inner: vec![0u8; MAX_BUCKETS],
        };
        assert_eq!(p4.last_set_bit(), None);

        // An extra test to ensure left to right encoding.
        let mut p5 = BitVec {
            inner: vec![0b0000_0001, 0b0100_0000],
        };
        assert_eq!(p5.last_set_bit(), Some(9));
        assert!(p5.is_set(7));
        assert!(p5.is_set(9));
        assert!(!p5.is_set(0));

        p5.set(10);
        assert!(p5.is_set(10));
        assert_eq!(p5.last_set_bit(), Some(10));
        assert_eq!(p5.inner, vec![0b0000_0001, 0b0110_0000]);

        let p6 = BitVec {
            inner: vec![0b1000_0000],
        };
        assert_eq!(p6.inner.len(), 1);
        assert_eq!(p6.last_set_bit(), Some(0));
    }

    /// An empty `BitVec` reports every `u16` position as unset.
    #[test]
    fn test_empty() {
        let p = BitVec::default();
        for i in 0..=u16::MAX {
            assert!(!p.is_set(i));
        }
    }

    /// Setting only the lowest and highest positions leaves every position in between unset.
    #[test]
    fn test_extremes() {
        let mut p = BitVec::default();
        p.set(u16::MAX);
        p.set(0);
        assert!(p.is_set(u16::MAX));
        assert!(p.is_set(0));
        for i in 1..u16::MAX {
            assert!(!p.is_set(i));
        }
        assert_eq!(
            vec![0, u16::MAX as usize],
            p.iter_ones().collect::<Vec<_>>()
        );
    }

    /// Converting from `Vec<bool>` maps element `i` to position `i`.
    #[test]
    fn test_conversion() {
        let bitmaps = vec![
            false, true, true, false, false, true, true, false, true, true, true,
        ];
        let bitvec = BitVec::from(bitmaps.clone());
        for (index, is_set) in bitmaps.into_iter().enumerate() {
            assert_eq!(bitvec.is_set(index as u16), is_set);
        }
    }

    /// Deserialization rejects byte vectors longer than `MAX_BUCKETS` and accepts shorter ones.
    #[test]
    fn test_deserialization() {
        let raw = vec![0u8; 9000];
        let bytes = bcs::to_bytes(&raw).unwrap();
        assert!(bcs::from_bytes::<Vec<u8>>(&bytes).is_ok());
        // 9000 > MAX_BUCKETS:
        assert!(bcs::from_bytes::<BitVec>(&bytes).is_err());
        // Hand-built payload: a leading byte of 32 followed by 32 zero bytes, which is expected
        // to decode to a 32-bucket all-zero BitVec.
        let mut bytes = [0u8; 33];
        bytes[0] = 32;
        let bv = BitVec {
            inner: Vec::from([0u8; 32].as_ref()),
        };
        assert_eq!(Ok(bv), bcs::from_bytes::<BitVec>(&bytes));
    }

    // Property tests for the bitwise operators, `iter_ones` and the serde round trip.
    proptest! {
        // Bitwise AND: a position is set in the result iff it is set in both inputs.
        #[test]
        fn test_and(bv1 in any::<BitVec>(), bv2 in any::<BitVec>()) {
            let intersection = bv1.bitand(&bv2);

            assert!(intersection.count_ones() <= bv1.count_ones());
            assert!(intersection.count_ones() <= bv2.count_ones());

            for i in 0..=u16::MAX {
                if bv1.is_set(i) && bv2.is_set(i) {
                    assert!(intersection.is_set(i));
                } else {
                    assert!(!intersection.is_set(i));
                }
            }
        }

        // Bitwise OR: a position is set in the result iff it is set in at least one input.
        #[test]
        fn test_or(bv1 in any::<BitVec>(), bv2 in any::<BitVec>()) {
            let union = bv1.bitor(&bv2);

            assert!(union.count_ones() >= bv1.count_ones());
            assert!(union.count_ones() >= bv2.count_ones());

            for i in 0..=u16::MAX {
                if bv1.is_set(i) || bv2.is_set(i) {
                    assert!(union.is_set(i));
                } else {
                    assert!(!union.is_set(i));
                }
            }
        }

        // `iter_ones` yields exactly as many positions as `count_ones` reports.
        #[test]
        fn test_iter_ones(bv1 in any::<BitVec>()) {
            assert_eq!(bv1.iter_ones().count(), bv1.count_ones() as usize);
        }

        // Serializing to JSON and back yields an equal BitVec.
        #[test]
        fn test_serde_roundtrip(bits in vec(any::<bool>(), 0..u16::MAX as usize)) {
            let bitvec = BitVec::from(bits);
            let bytes = serde_json::to_vec(&bitvec).unwrap();
            let back = serde_json::from_slice(&bytes).unwrap();
            assert_eq!(bitvec, back);
        }

    }
}